Initial commit: hello-agent — headless RustDesk-protocol-compatible Windows agent
build-windows / build-hello-agent-x64 (push) Successful in 5m41s

A single-binary, Flutter-free remote-support agent that speaks the stock
RustDesk wire protocol. Designed for one-line MDM deployment against a
self-hosted rustdesk-server: a supporter using the unmodified rustdesk.exe
client connects, the controlled-side user gets a native Win32 approval
prompt, click Yes / No.

CLI surface

    hello-agent.exe --install                # register + start service
    hello-agent.exe --uninstall              # stop, delete, clean up
    hello-agent.exe --config <BLOB>          # admin-UI deploy string
    hello-agent.exe --install --config <BLOB>   # MDM one-liner

--config accepts both forms emitted by the rustdesk-server admin UI: the
reversed-base64 deploy string and the host=,key=,api=,relay= filename
form. Decoded via the upstream custom_server module, persisted via
hbb_common::config::Config::set_option.

Architecture

    --service runs as a Session 0 LocalSystem service. It polls
    WTSGetActiveConsoleSessionId and (re)spawns hello-agent.exe --server
    into the active console session via librustdesk::platform::run_as_user,
    handling the Session 0 → user-session token impersonation.

    --server is the worker. It boots three concurrent components:
      1. cm_popup: an IPC listener on the rustdesk `_cm` named pipe
      2. librustdesk::start_server(true, false): the upstream protocol
         stack — rendezvous mediator, NAT punch, IPC server, screen
         capture, login validation, hbbs_http heartbeat / sysinfo sync
      3. (implicit) ApproveMode::Click is pinned in config, so every
         incoming connection routes through cm_popup

The popup mechanism reuses an existing upstream contract without any
patches to the protocol code: when a peer connects with no password,
Connection::start in the upstream code calls try_start_cm_ipc, which
ipc::connect-s the `_cm` pipe before falling back to spawning a Flutter
CM child. Since cm_popup is up first, step 1 succeeds; we read the
Data::Login{authorized:false} frame, show MessageBoxTimeoutW (Yes/No,
60s, top-most, system-modal), and reply Data::Authorize or Data::Close.

Source tree

    src/main.rs             CLI dispatcher + run_server() composition
    src/cli.rs              hand-rolled argv parser + unit tests
    src/service.rs          windows-service install/uninstall/dispatcher
    src/config_import.rs    --config blob decoding + persistence
    src/cm_popup.rs         _cm IPC listener + Win32 approval dialog

Vendoring

The upstream RustDesk crate is vendored under vendor/rustdesk/ — full
workspace including libs/{hbb_common, scrap, enigo, clipboard,
virtual_display, remote_printer}. This makes the build self-contained
(no submodules, no sibling-repo checkout in CI) and gives us freedom to
fork in a different direction later. Excluded from the vendor: .git,
target/, flutter/, appimage/, flatpak/, fastlane/, docs/, examples/,
ci/, build.py, Dockerfile, upstream README/CLAUDE/AGENTS/GEMINI.

One local divergence vs. upstream: vendor/rustdesk/src/lib.rs flips
`mod custom_server` → `pub mod custom_server` so config_import.rs can
call get_custom_server_from_string without going through the
ui_interface shim. Documented in README.md → "Re-syncing the vendored
copy".

CI

.gitea/workflows/build-windows.yml builds on a self-hosted Windows
runner with Rust 1.75, LLVM 15.0.6 (libclang for bindgen via libvpx-sys),
and a vcpkg cache. The vendored vcpkg.json drives x64-windows-static
deps. The workflow stages the resulting hello-agent.exe into
SignOutput\, reports authenticode signing status (warns on unsigned),
and uploads as artifact. ~15 min full build, faster on incremental.

Out of scope for this commit: Linux/macOS builds, code signing, MSI
packaging, coexistence with stock rustdesk on the same box (currently
shares the RustDesk APP_NAME and config dir).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-07 11:01:30 +02:00
commit f8ead215d8
479 changed files with 188052 additions and 0 deletions
+16
View File
@@ -0,0 +1,16 @@
[target.x86_64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+crt-static"]
[target.i686-pc-windows-msvc]
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-args=/NODEFAULTLIB:MSVCRT"]
[target.'cfg(target_os="macos")']
rustflags = [
"-C", "link-args=-sectcreate __CGPreLoginApp __cgpreloginapp /dev/null",
]
#[target.'cfg(target_os="linux")']
# glibc-static required, this may fix https://github.com/rustdesk/rustdesk/issues/9103, but I do not want this big change
# this is unlikely to help also, because the other so files still use libc dynamically
#rustflags = [
# "-C", "link-args=-Wl,-Bstatic -lc -Wl,-Bdynamic"
#]
[net]
git-fetch-with-cli = true
+1
View File
@@ -0,0 +1 @@
* text=auto
+258
View File
@@ -0,0 +1,258 @@
[package]
name = "rustdesk"
version = "1.4.6"
authors = ["rustdesk <info@rustdesk.com>"]
edition = "2021"
build= "build.rs"
description = "RustDesk Remote Desktop"
default-run = "rustdesk"
rust-version = "1.75"
[lib]
name = "librustdesk"
# Local divergence vs upstream rustdesk: ["cdylib", "staticlib", "rlib"].
# hello-agent statically links the rlib into hello-agent.exe and never
# needs the cdylib (used by upstream for Flutter FFI) or the staticlib.
# Cargo builds all crate-types of a [lib] together, and the cdylib link
# step aggregates multiple windows-targets/windows_x86_64_msvc versions
# into one DLL alongside the explicitly-linked windows.lib import library,
# producing LNK1169 multiply-defined-symbol failures. Restricting to rlib
# skips the cdylib link entirely and is fine for our consumer.
crate-type = ["rlib"]
[[bin]]
name = "naming"
path = "src/naming.rs"
[[bin]]
name = "service"
path = "src/service.rs"
[features]
inline = []
cli = []
use_samplerate = ["samplerate"]
use_rubato = ["rubato"]
use_dasp = ["dasp"]
flutter = ["flutter_rust_bridge"]
default = ["use_dasp"]
hwcodec = ["scrap/hwcodec"]
vram = ["scrap/vram"]
mediacodec = ["scrap/mediacodec"]
plugin_framework = []
linux-pkg-config = ["magnum-opus/linux-pkg-config", "scrap/linux-pkg-config"]
unix-file-copy-paste = [
"dep:x11-clipboard",
"dep:x11rb",
"dep:percent-encoding",
"dep:once_cell",
"clipboard/unix-file-copy-paste",
]
screencapturekit = ["cpal/screencapturekit"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-trait = "0.1"
scrap = { path = "libs/scrap", features = ["wayland"] }
hbb_common = { path = "libs/hbb_common" }
serde_derive = "1.0"
serde = "1.0"
serde_json = "1.0"
serde_repr = "0.1"
cfg-if = "1.0"
lazy_static = "1.4"
sha2 = "0.10"
repng = "0.2"
parity-tokio-ipc = { git = "https://github.com/rustdesk-org/parity-tokio-ipc" }
magnum-opus = { git = "https://github.com/rustdesk-org/magnum-opus" }
dasp = { version = "0.11", features = ["signal", "interpolate-linear", "interpolate"], optional = true }
rubato = { version = "0.12", optional = true }
samplerate = { version = "0.2", optional = true }
uuid = { version = "1.3", features = ["v4"] }
clap = "4.2"
rpassword = "7.2"
num_cpus = "1.15"
bytes = { version = "1.4", features = ["serde"] }
default-net = "0.14"
wol-rs = "1.0"
flutter_rust_bridge = { version = "=1.80", features = ["uuid"], optional = true}
errno = "0.3"
rdev = { git = "https://github.com/rustdesk-org/rdev" }
url = { version = "2.3", features = ["serde"] }
crossbeam-queue = "0.3"
hex = "0.4"
chrono = "0.4"
cidr-utils = "0.5"
fon = "0.6"
zip = "0.6"
shutdown_hooks = "0.1"
totp-rs = { version = "5.4", default-features = false, features = ["gen_secret", "otpauth"] }
stunclient = "0.4"
kcp-sys= { git = "https://github.com/rustdesk-org/kcp-sys"}
reqwest = { version = "0.12", features = ["blocking", "socks", "json", "native-tls", "rustls-tls", "rustls-tls-native-roots", "gzip"], default-features=false }
[target.'cfg(not(target_os = "linux"))'.dependencies]
# https://github.com/rustdesk/rustdesk/discussions/10197, not use cpal on linux
cpal = { git = "https://github.com/rustdesk-org/cpal", branch = "osx-screencapturekit" }
ringbuf = "0.3"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
sciter-rs = { git = "https://github.com/rustdesk-org/rust-sciter", branch = "dyn" }
sys-locale = "0.3"
enigo = { path = "libs/enigo", features = [ "with_serde" ] }
clipboard = { path = "libs/clipboard" }
ctrlc = "3.2"
# arboard = { version = "3.4", features = ["wayland-data-control"] }
arboard = { git = "https://github.com/rustdesk-org/arboard", features = ["wayland-data-control"] }
clipboard-master = { git = "https://github.com/rustdesk-org/clipboard-master" }
portable-pty = { git = "https://github.com/rustdesk-org/wezterm", branch = "rustdesk/pty_based_0.8.1", package = "portable-pty" }
system_shutdown = "4.0"
qrcode-generator = "4.1"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = [
"winuser",
"wincrypt",
"shellscalingapi",
"pdh",
"synchapi",
"memoryapi",
"shellapi",
"devguid",
"setupapi",
"cguid",
"cfgmgr32",
"ioapiset",
"winspool",
] }
windows = { version = "0.61", features = [
"Win32",
"Win32_Foundation",
"Win32_Security",
"Win32_Security_Authorization",
"Win32_Storage_FileSystem",
"Win32_System",
"Win32_System_Diagnostics",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_Environment",
"Win32_System_IO",
"Win32_System_Memory",
"Win32_System_Pipes",
"Win32_System_Threading",
"Win32_UI_Shell",
] }
winreg = "0.11"
windows-service = "0.6"
virtual_display = { path = "libs/virtual_display" }
remote_printer = { path = "libs/remote_printer" }
impersonate_system = { git = "https://github.com/rustdesk-org/impersonate-system" }
shared_memory = "0.12"
tauri-winrt-notification = "0.1"
runas = "1.2"
[target.'cfg(target_os = "macos")'.dependencies]
objc = "0.2"
cocoa = "0.24"
dispatch = "0.2"
core-foundation = "0.9"
core-graphics = "0.22"
include_dir = "0.7"
fruitbasket = "0.10"
objc_id = "0.1"
# If we use piet "0.7" here, we must also update core-graphics to "0.24".
piet = "0.6"
piet-coregraphics = "0.6"
foreign-types = "0.3"
[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
tray-icon = { git = "https://github.com/tauri-apps/tray-icon", version = "0.21.3" }
tao = { git = "https://github.com/rustdesk-org/tao", branch = "dev" }
image = "0.24"
[target.'cfg(any(target_os = "macos", target_os = "linux"))'.dependencies]
keepawake = { git = "https://github.com/rustdesk-org/keepawake-rs" }
[target.'cfg(any(target_os = "windows", target_os = "linux"))'.dependencies]
wallpaper = { git = "https://github.com/rustdesk-org/wallpaper.rs" }
tiny-skia = "0.11"
softbuffer = "0.4"
fontdb = "0.23"
bytemuck = "1.23"
ttf-parser = "0.25"
[target.'cfg(target_os = "linux")'.dependencies]
libxdo-sys = "0.11"
psimple = { package = "libpulse-simple-binding", version = "2.27" }
pulse = { package = "libpulse-binding", version = "2.27" }
rust-pulsectl = { git = "https://github.com/rustdesk-org/pulsectl" }
async-process = "1.7"
evdev = { git="https://github.com/rustdesk-org/evdev" }
dbus = "0.9"
dbus-crossroads = "0.5"
pam = { git="https://github.com/rustdesk-org/pam" }
x11-clipboard = {git="https://github.com/clslaid/x11-clipboard", branch = "feat/store-batch", optional = true}
x11rb = {version = "0.12", features = ["all-extensions"], optional = true}
percent-encoding = {version = "2.3", optional = true}
once_cell = {version = "1.18", optional = true}
nix = { version = "0.29", features = ["term", "process"]}
gtk = "0.18"
termios = "0.3"
terminfo = "0.8"
winit = "0.30"
[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
openssl = { version = "0.10", features = ["vendored"] }
[target.'cfg(target_os = "android")'.dependencies]
android_logger = "0.13"
jni = "0.21"
android-wakelock = { git = "https://github.com/rustdesk-org/android-wakelock" }
[workspace]
members = ["libs/scrap", "libs/hbb_common", "libs/enigo", "libs/clipboard", "libs/virtual_display", "libs/virtual_display/dylib", "libs/portable", "libs/remote_printer"]
exclude = ["vdi/host", "examples/custom_plugin"]
# Patch libxdo-sys to use a stub implementation that doesn't require libxdo
# This allows building and running on systems without libxdo installed (e.g., Wayland-only)
[patch.crates-io]
libxdo-sys = { path = "libs/libxdo-sys-stub" }
[package.metadata.winres]
LegalCopyright = "Copyright © 2025 cStudio GmbH. All rights reserved."
ProductName = "RustDesk"
FileDescription = "RustDesk Remote Desktop"
OriginalFilename = "rustdesk.exe"
[target.'cfg(target_os="windows")'.build-dependencies]
winres = "0.1"
winapi = { version = "0.3", features = [ "winnt", "pdh", "synchapi" ] }
[build-dependencies]
cc = "1.0"
hbb_common = { path = "libs/hbb_common" }
os-version = "0.2"
[dev-dependencies]
hound = "3.5"
docopt = "1.1"
[package.metadata.bundle]
name = "RustDesk"
identifier = "com.carriez.rustdesk"
icon = ["res/32x32.png", "res/128x128.png", "res/128x128@2x.png"]
osx_minimum_system_version = "10.14"
#https://github.com/johnthagen/min-sized-rust
[profile.release]
lto = true
codegen-units = 1
panic = 'abort'
strip = true
#opt-level = 'z' # only have smaller size after strip
rpath = true
[profile.dev]
debug = 1
+661
View File
@@ -0,0 +1,661 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
+94
View File
@@ -0,0 +1,94 @@
#[cfg(windows)]
fn build_windows() {
let file = "src/platform/windows.cc";
let file2 = "src/platform/windows_delete_test_cert.cc";
cc::Build::new().file(file).file(file2).compile("windows");
println!("cargo:rustc-link-lib=WtsApi32");
println!("cargo:rerun-if-changed={}", file);
println!("cargo:rerun-if-changed={}", file2);
}
#[cfg(target_os = "macos")]
fn build_mac() {
let file = "src/platform/macos.mm";
let mut b = cc::Build::new();
if let Ok(os_version::OsVersion::MacOS(v)) = os_version::detect() {
let v = v.version;
if v.contains("10.14") {
b.flag("-DNO_InputMonitoringAuthStatus=1");
}
}
b.flag("-std=c++17").file(file).compile("macos");
println!("cargo:rerun-if-changed={}", file);
}
#[cfg(all(windows, feature = "inline"))]
fn build_manifest() {
use std::io::Write;
if std::env::var("PROFILE").unwrap() == "release" {
let mut res = winres::WindowsResource::new();
res.set_icon("res/icon.ico")
.set_language(winapi::um::winnt::MAKELANGID(
winapi::um::winnt::LANG_ENGLISH,
winapi::um::winnt::SUBLANG_ENGLISH_US,
))
.set_manifest_file("res/manifest.xml");
match res.compile() {
Err(e) => {
write!(std::io::stderr(), "{}", e).unwrap();
std::process::exit(1);
}
Ok(_) => {}
}
}
}
fn install_android_deps() {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
if target_os != "android" {
return;
}
let mut target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
if target_arch == "x86_64" {
target_arch = "x64".to_owned();
} else if target_arch == "x86" {
target_arch = "x86".to_owned();
} else if target_arch == "aarch64" {
target_arch = "arm64".to_owned();
} else {
target_arch = "arm".to_owned();
}
let target = format!("{}-android", target_arch);
let vcpkg_root = std::env::var("VCPKG_ROOT").unwrap();
let mut path: std::path::PathBuf = vcpkg_root.into();
if let Ok(vcpkg_root) = std::env::var("VCPKG_INSTALLED_ROOT") {
path = vcpkg_root.into();
} else {
path.push("installed");
}
path.push(target);
println!(
"cargo:rustc-link-search={}",
path.join("lib").to_str().unwrap()
);
println!("cargo:rustc-link-lib=ndk_compat");
println!("cargo:rustc-link-lib=oboe");
println!("cargo:rustc-link-lib=c++");
println!("cargo:rustc-link-lib=OpenSLES");
}
fn main() {
hbb_common::gen_version();
install_android_deps();
#[cfg(all(windows, feature = "inline"))]
build_manifest();
#[cfg(windows)]
build_windows();
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
if target_os == "macos" {
#[cfg(target_os = "macos")]
build_mac();
println!("cargo:rustc-link-lib=framework=ApplicationServices");
}
println!("cargo:rerun-if-changed=build.rs");
}
+57
View File
@@ -0,0 +1,57 @@
[package]
name = "clipboard"
version = "0.1.0"
edition = "2021"
build = "build.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
cc = "1.0"
[features]
default = []
unix-file-copy-paste = [
"dep:x11rb",
"dep:x11-clipboard",
"dep:rand",
"dep:fuser",
"dep:libc",
"dep:dashmap",
"dep:percent-encoding",
"dep:utf16string",
"dep:once_cell",
"dep:cacao"
]
[dependencies]
thiserror = "1.0"
lazy_static = "1.4"
serde = "1.0"
serde_derive = "1.0"
hbb_common = { path = "../hbb_common" }
parking_lot = {version = "0.12"}
[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dependencies]
rand = {version = "0.8", optional = true}
libc = {version = "0.2", optional = true}
dashmap = {version ="5.5", optional = true}
utf16string = {version = "0.2", optional = true}
once_cell = {version = "1.18", optional = true}
[target.'cfg(target_os = "linux")'.dependencies]
percent-encoding = {version ="2.3", optional = true}
x11-clipboard = {git="https://github.com/clslaid/x11-clipboard", branch = "feat/store-batch", optional = true}
x11rb = {version = "0.12", features = ["all-extensions"], optional = true}
fuser = {version = "0.15", default-features = false, optional = true}
[target.'cfg(target_os = "macos")'.dependencies]
cacao = {git="https://github.com/clslaid/cacao", branch = "feat/set-file-urls", optional = true}
# Use `relax-void-encoding`, as that allows us to pass `c_void` instead of implementing `Encode` correctly for `&CGImageRef`
objc2 = { version = "0.5.1", features = ["relax-void-encoding"] }
objc2-foundation = { version = "0.2.0", features = ["NSArray", "NSString", "NSEnumerator", "NSGeometry", "NSProgress"] }
objc2-app-kit = { version = "0.2.0", features = ["NSPasteboard", "NSPasteboardItem", "NSImage", "NSFilePromiseProvider"] }
uuid = { version = "1.3", features = ["v4"] }
fsevent = "2.1.2"
dirs = "5.0"
xattr = "1.4.0"
+35
View File
@@ -0,0 +1,35 @@
#[cfg(target_os = "windows")]
fn build_c_impl() {
let mut build = cc::Build::new();
build.file("src/windows/wf_cliprdr.c");
{
build.flag_if_supported("-Wno-c++0x-extensions");
build.flag_if_supported("-Wno-return-type-c-linkage");
build.flag_if_supported("-Wno-invalid-offsetof");
build.flag_if_supported("-Wno-unused-parameter");
if build.get_compiler().is_like_msvc() {
build.define("WIN32", "");
// build.define("_AMD64_", "");
build.flag("-Z7");
build.flag("-GR-");
// build.flag("-std:c++11");
} else {
build.flag("-fPIC");
// build.flag("-std=c++11");
// build.flag("-include");
// build.flag(&confdefs_path.to_string_lossy());
}
build.compile("mycliprdr");
}
println!("cargo:rerun-if-changed=src/windows/wf_cliprdr.c");
}
fn main() {
#[cfg(target_os = "windows")]
build_c_impl();
}
+247
View File
@@ -0,0 +1,247 @@
#ifndef WF_CLIPRDR_H__
#define WF_CLIPRDR_H__
#ifdef __cplusplus
extern "C"
{
#endif
typedef signed char INT8, *PINT8;
typedef signed short INT16, *PINT16;
typedef signed int INT32, *PINT32;
typedef unsigned char UINT8, *PUINT8;
typedef unsigned short UINT16, *PUINT16;
typedef unsigned int UINT32, *PUINT32;
typedef unsigned int UINT;
typedef int BOOL;
typedef unsigned char BYTE;
/* Clipboard Messages */
#define DEFINE_CLIPRDR_HEADER_COMMON() \
UINT32 connID; \
UINT16 msgType; \
UINT16 msgFlags; \
UINT32 dataLen
struct _CLIPRDR_HEADER
{
DEFINE_CLIPRDR_HEADER_COMMON();
};
typedef struct _CLIPRDR_HEADER CLIPRDR_HEADER;
struct _CLIPRDR_CAPABILITY_SET
{
UINT16 capabilitySetType;
UINT16 capabilitySetLength;
};
typedef struct _CLIPRDR_CAPABILITY_SET CLIPRDR_CAPABILITY_SET;
struct _CLIPRDR_GENERAL_CAPABILITY_SET
{
UINT16 capabilitySetType;
UINT16 capabilitySetLength;
UINT32 version;
UINT32 generalFlags;
};
typedef struct _CLIPRDR_GENERAL_CAPABILITY_SET CLIPRDR_GENERAL_CAPABILITY_SET;
struct _CLIPRDR_CAPABILITIES
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 cCapabilitiesSets;
CLIPRDR_CAPABILITY_SET *capabilitySets;
};
typedef struct _CLIPRDR_CAPABILITIES CLIPRDR_CAPABILITIES;
struct _CLIPRDR_MONITOR_READY
{
DEFINE_CLIPRDR_HEADER_COMMON();
};
typedef struct _CLIPRDR_MONITOR_READY CLIPRDR_MONITOR_READY;
struct _CLIPRDR_TEMP_DIRECTORY
{
DEFINE_CLIPRDR_HEADER_COMMON();
char szTempDir[520];
};
typedef struct _CLIPRDR_TEMP_DIRECTORY CLIPRDR_TEMP_DIRECTORY;
struct _CLIPRDR_FORMAT
{
UINT32 formatId;
char *formatName;
};
typedef struct _CLIPRDR_FORMAT CLIPRDR_FORMAT;
struct _CLIPRDR_FORMAT_LIST
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 numFormats;
CLIPRDR_FORMAT *formats;
};
typedef struct _CLIPRDR_FORMAT_LIST CLIPRDR_FORMAT_LIST;
struct _CLIPRDR_FORMAT_LIST_RESPONSE
{
DEFINE_CLIPRDR_HEADER_COMMON();
};
typedef struct _CLIPRDR_FORMAT_LIST_RESPONSE CLIPRDR_FORMAT_LIST_RESPONSE;
struct _CLIPRDR_LOCK_CLIPBOARD_DATA
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 clipDataId;
};
typedef struct _CLIPRDR_LOCK_CLIPBOARD_DATA CLIPRDR_LOCK_CLIPBOARD_DATA;
struct _CLIPRDR_UNLOCK_CLIPBOARD_DATA
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 clipDataId;
};
typedef struct _CLIPRDR_UNLOCK_CLIPBOARD_DATA CLIPRDR_UNLOCK_CLIPBOARD_DATA;
struct _CLIPRDR_FORMAT_DATA_REQUEST
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 requestedFormatId;
};
typedef struct _CLIPRDR_FORMAT_DATA_REQUEST CLIPRDR_FORMAT_DATA_REQUEST;
struct _CLIPRDR_FORMAT_DATA_RESPONSE
{
DEFINE_CLIPRDR_HEADER_COMMON();
const BYTE *requestedFormatData;
};
typedef struct _CLIPRDR_FORMAT_DATA_RESPONSE CLIPRDR_FORMAT_DATA_RESPONSE;
struct _CLIPRDR_FILE_CONTENTS_REQUEST
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 streamId;
UINT32 listIndex;
UINT32 dwFlags;
UINT32 nPositionLow;
UINT32 nPositionHigh;
UINT32 cbRequested;
BOOL haveClipDataId;
UINT32 clipDataId;
};
typedef struct _CLIPRDR_FILE_CONTENTS_REQUEST CLIPRDR_FILE_CONTENTS_REQUEST;
struct _CLIPRDR_FILE_CONTENTS_RESPONSE
{
DEFINE_CLIPRDR_HEADER_COMMON();
UINT32 streamId;
UINT32 cbRequested;
const BYTE *requestedData;
};
typedef struct _CLIPRDR_FILE_CONTENTS_RESPONSE CLIPRDR_FILE_CONTENTS_RESPONSE;
typedef struct _cliprdr_client_context CliprdrClientContext;
struct _NOTIFICATION_MESSAGE
{
// 0 - info, 1 - warning, 2 - error
UINT32 type;
char *msg;
char *details;
};
typedef struct _NOTIFICATION_MESSAGE NOTIFICATION_MESSAGE;
typedef UINT (*pcCliprdrServerCapabilities)(CliprdrClientContext *context,
const CLIPRDR_CAPABILITIES *capabilities);
typedef UINT (*pcCliprdrClientCapabilities)(CliprdrClientContext *context,
const CLIPRDR_CAPABILITIES *capabilities);
typedef UINT (*pcCliprdrMonitorReady)(CliprdrClientContext *context,
const CLIPRDR_MONITOR_READY *monitorReady);
typedef UINT (*pcCliprdrTempDirectory)(CliprdrClientContext *context,
const CLIPRDR_TEMP_DIRECTORY *tempDirectory);
typedef UINT (*pcNotifyClipboardMsg)(UINT32 connID, const NOTIFICATION_MESSAGE *msg);
typedef UINT (*pcHandleClipboardFiles)(UINT32 connID, size_t nFiles, WCHAR **fileNames);
typedef UINT (*pcCliprdrClientFormatList)(CliprdrClientContext *context,
const CLIPRDR_FORMAT_LIST *formatList);
typedef UINT (*pcCliprdrServerFormatList)(CliprdrClientContext *context,
const CLIPRDR_FORMAT_LIST *formatList);
typedef UINT (*pcCliprdrClientFormatListResponse)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_LIST_RESPONSE *formatListResponse);
typedef UINT (*pcCliprdrServerFormatListResponse)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_LIST_RESPONSE *formatListResponse);
typedef UINT (*pcCliprdrClientLockClipboardData)(
CliprdrClientContext *context, const CLIPRDR_LOCK_CLIPBOARD_DATA *lockClipboardData);
typedef UINT (*pcCliprdrServerLockClipboardData)(
CliprdrClientContext *context, const CLIPRDR_LOCK_CLIPBOARD_DATA *lockClipboardData);
typedef UINT (*pcCliprdrClientUnlockClipboardData)(
CliprdrClientContext *context, const CLIPRDR_UNLOCK_CLIPBOARD_DATA *unlockClipboardData);
typedef UINT (*pcCliprdrServerUnlockClipboardData)(
CliprdrClientContext *context, const CLIPRDR_UNLOCK_CLIPBOARD_DATA *unlockClipboardData);
typedef UINT (*pcCliprdrClientFormatDataRequest)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_DATA_REQUEST *formatDataRequest);
typedef UINT (*pcCliprdrServerFormatDataRequest)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_DATA_REQUEST *formatDataRequest);
typedef UINT (*pcCliprdrClientFormatDataResponse)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_DATA_RESPONSE *formatDataResponse);
typedef UINT (*pcCliprdrServerFormatDataResponse)(
CliprdrClientContext *context, const CLIPRDR_FORMAT_DATA_RESPONSE *formatDataResponse);
typedef UINT (*pcCliprdrClientFileContentsRequest)(
CliprdrClientContext *context, const CLIPRDR_FILE_CONTENTS_REQUEST *fileContentsRequest);
typedef UINT (*pcCliprdrServerFileContentsRequest)(
CliprdrClientContext *context, const CLIPRDR_FILE_CONTENTS_REQUEST *fileContentsRequest);
typedef UINT (*pcCliprdrClientFileContentsResponse)(
CliprdrClientContext *context, const CLIPRDR_FILE_CONTENTS_RESPONSE *fileContentsResponse);
typedef UINT (*pcCliprdrServerFileContentsResponse)(
CliprdrClientContext *context, const CLIPRDR_FILE_CONTENTS_RESPONSE *fileContentsResponse);
// TODO: hide more members of clipboard context
struct _cliprdr_client_context
{
void *Custom;
BOOL EnableFiles;
BOOL EnableOthers;
BOOL IsStopped;
UINT32 ResponseWaitTimeoutSecs;
pcCliprdrServerCapabilities ServerCapabilities;
pcCliprdrClientCapabilities ClientCapabilities;
pcCliprdrMonitorReady MonitorReady;
pcCliprdrTempDirectory TempDirectory;
pcNotifyClipboardMsg NotifyClipboardMsg;
pcHandleClipboardFiles HandleClipboardFiles;
pcCliprdrClientFormatList ClientFormatList;
pcCliprdrServerFormatList ServerFormatList;
pcCliprdrClientFormatListResponse ClientFormatListResponse;
pcCliprdrServerFormatListResponse ServerFormatListResponse;
pcCliprdrClientLockClipboardData ClientLockClipboardData;
pcCliprdrServerLockClipboardData ServerLockClipboardData;
pcCliprdrClientUnlockClipboardData ClientUnlockClipboardData;
pcCliprdrServerUnlockClipboardData ServerUnlockClipboardData;
pcCliprdrClientFormatDataRequest ClientFormatDataRequest;
pcCliprdrServerFormatDataRequest ServerFormatDataRequest;
pcCliprdrClientFormatDataResponse ClientFormatDataResponse;
pcCliprdrServerFormatDataResponse ServerFormatDataResponse;
pcCliprdrClientFileContentsRequest ClientFileContentsRequest;
pcCliprdrServerFileContentsRequest ServerFileContentsRequest;
pcCliprdrClientFileContentsResponse ClientFileContentsResponse;
pcCliprdrServerFileContentsResponse ServerFileContentsResponse;
UINT32 LastRequestedFormatId;
};
#ifdef __cplusplus
}
#endif
#endif // WF_CLIPRDR_H__
+79
View File
@@ -0,0 +1,79 @@
use hbb_common::{log, ResultType};
use std::{ops::Deref, sync::Mutex};
use crate::CliprdrServiceContext;
const CLIPBOARD_RESPONSE_WAIT_TIMEOUT_SECS: u32 = 30;
lazy_static::lazy_static! {
static ref CONTEXT_SEND: ContextSend = ContextSend::default();
}
#[derive(Default)]
pub struct ContextSend(Mutex<Option<Box<dyn CliprdrServiceContext>>>);
impl Deref for ContextSend {
type Target = Mutex<Option<Box<dyn CliprdrServiceContext>>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ContextSend {
#[inline]
pub fn is_enabled() -> bool {
CONTEXT_SEND.lock().unwrap().is_some()
}
pub fn set_is_stopped() {
let _res = Self::proc(|c| c.set_is_stopped().map_err(|e| e.into()));
}
pub fn enable(enabled: bool) {
let mut lock = CONTEXT_SEND.lock().unwrap();
if enabled {
if lock.is_some() {
return;
}
match crate::create_cliprdr_context(true, false, CLIPBOARD_RESPONSE_WAIT_TIMEOUT_SECS) {
Ok(context) => {
log::info!("clipboard context for file transfer created.");
*lock = Some(context)
}
Err(err) => {
log::error!(
"create clipboard context for file transfer: {}",
err.to_string()
);
}
}
} else if let Some(_clp) = lock.take() {
*lock = None;
log::info!("clipboard context for file transfer destroyed.");
}
}
/// make sure the clipboard context is enabled.
pub fn make_sure_enabled() -> ResultType<()> {
let mut lock = CONTEXT_SEND.lock().unwrap();
if lock.is_some() {
return Ok(());
}
let ctx = crate::create_cliprdr_context(true, false, CLIPBOARD_RESPONSE_WAIT_TIMEOUT_SECS)?;
*lock = Some(ctx);
log::info!("clipboard context for file transfer recreated.");
Ok(())
}
pub fn proc<F: FnOnce(&mut Box<dyn CliprdrServiceContext>) -> ResultType<()>>(
f: F,
) -> ResultType<()> {
let mut lock = CONTEXT_SEND.lock().unwrap();
match lock.as_mut() {
Some(context) => f(context),
None => Ok(()),
}
}
}
+298
View File
@@ -0,0 +1,298 @@
use std::sync::{Arc, Mutex, RwLock};
#[cfg(any(
target_os = "windows",
all(target_os = "macos", feature = "unix-file-copy-paste")
))]
use hbb_common::ResultType;
#[cfg(any(target_os = "windows", feature = "unix-file-copy-paste"))]
use hbb_common::{allow_err, log};
use hbb_common::{
lazy_static,
tokio::sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
Mutex as TokioMutex,
},
};
use serde_derive::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(any(
target_os = "windows",
all(target_os = "macos", feature = "unix-file-copy-paste")
))]
pub mod context_send;
pub mod platform;
#[cfg(any(
target_os = "windows",
all(target_os = "macos", feature = "unix-file-copy-paste")
))]
pub use context_send::*;
#[cfg(target_os = "windows")]
const ERR_CODE_SERVER_FUNCTION_NONE: u32 = 0x00000001;
#[cfg(target_os = "windows")]
const ERR_CODE_INVALID_PARAMETER: u32 = 0x00000002;
#[cfg(target_os = "windows")]
const ERR_CODE_SEND_MSG: u32 = 0x00000003;
#[cfg(any(
target_os = "windows",
all(target_os = "macos", feature = "unix-file-copy-paste")
))]
pub(crate) use platform::create_cliprdr_context;
pub struct ProgressPercent {
pub percent: f64,
pub is_canceled: bool,
pub is_failed: bool,
}
// to-do: This trait may be removed, because unix file copy paste does not need it.
/// Ability to handle Clipboard File from remote rustdesk client
///
/// # Note
/// There actually should be 2 parts to implement a useable clipboard file service,
/// but this only contains the RPC server part.
/// The local listener and transport part is too platform specific to wrap up in typeclasses.
pub trait CliprdrServiceContext: Send + Sync {
/// set to be stopped
fn set_is_stopped(&mut self) -> Result<(), CliprdrError>;
/// clear the content on clipboard
fn empty_clipboard(&mut self, conn_id: i32) -> Result<bool, CliprdrError>;
/// run as a server for clipboard RPC
fn server_clip_file(&mut self, conn_id: i32, msg: ClipboardFile) -> Result<(), CliprdrError>;
/// get the progress of the paste task.
fn get_progress_percent(&self) -> Option<ProgressPercent>;
/// cancel the paste task.
fn cancel(&mut self);
}
#[derive(Error, Debug)]
pub enum CliprdrError {
#[error("invalid cliprdr name")]
CliprdrName,
#[error("failed to init cliprdr")]
CliprdrInit,
#[error("cliprdr out of memory")]
CliprdrOutOfMemory,
#[error("cliprdr internal error")]
ClipboardInternalError,
#[error("cliprdr occupied")]
ClipboardOccupied,
#[error("conversion failure")]
ConversionFailure,
#[error("failure to read clipboard")]
OpenClipboard,
#[error("failure to read file metadata or content, path: {path}, err: {err}")]
FileError { path: String, err: std::io::Error },
#[error("invalid request: {description}")]
InvalidRequest { description: String },
#[error("common request: {description}")]
CommonError { description: String },
#[error("unknown cliprdr error")]
Unknown(u32),
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "t", content = "c")]
pub enum ClipboardFile {
NotifyCallback {
r#type: String,
title: String,
text: String,
},
MonitorReady,
FormatList {
format_list: Vec<(i32, String)>,
},
FormatListResponse {
msg_flags: i32,
},
FormatDataRequest {
requested_format_id: i32,
},
FormatDataResponse {
msg_flags: i32,
format_data: Vec<u8>,
},
FileContentsRequest {
stream_id: i32,
list_index: i32,
dw_flags: i32,
n_position_low: i32,
n_position_high: i32,
cb_requested: i32,
have_clip_data_id: bool,
clip_data_id: i32,
},
FileContentsResponse {
msg_flags: i32,
stream_id: i32,
requested_data: Vec<u8>,
},
TryEmpty,
Files {
files: Vec<(String, u64)>,
},
}
struct MsgChannel {
peer_id: String,
conn_id: i32,
#[allow(dead_code)]
sender: UnboundedSender<ClipboardFile>,
receiver: Arc<TokioMutex<UnboundedReceiver<ClipboardFile>>>,
}
lazy_static::lazy_static! {
static ref VEC_MSG_CHANNEL: RwLock<Vec<MsgChannel>> = Default::default();
static ref CLIENT_CONN_ID_COUNTER: Mutex<i32> = Mutex::new(0);
}
impl ClipboardFile {
pub fn is_stopping_allowed(&self) -> bool {
matches!(
self,
ClipboardFile::MonitorReady
| ClipboardFile::FormatList { .. }
| ClipboardFile::FormatDataRequest { .. }
)
}
pub fn is_beginning_message(&self) -> bool {
matches!(
self,
ClipboardFile::MonitorReady | ClipboardFile::FormatList { .. }
)
}
}
pub fn get_client_conn_id(peer_id: &str) -> Option<i32> {
VEC_MSG_CHANNEL
.read()
.unwrap()
.iter()
.find(|x| x.peer_id == peer_id)
.map(|x| x.conn_id)
}
fn get_conn_id() -> i32 {
let mut lock = CLIENT_CONN_ID_COUNTER.lock().unwrap();
*lock += 1;
*lock
}
pub fn get_rx_cliprdr_client(
peer_id: &str,
) -> (i32, Arc<TokioMutex<UnboundedReceiver<ClipboardFile>>>) {
let mut lock = VEC_MSG_CHANNEL.write().unwrap();
match lock.iter().find(|x| x.peer_id == peer_id) {
Some(msg_channel) => (msg_channel.conn_id, msg_channel.receiver.clone()),
None => {
let (sender, receiver) = unbounded_channel();
let receiver = Arc::new(TokioMutex::new(receiver));
let receiver2 = receiver.clone();
let conn_id = get_conn_id();
let msg_channel = MsgChannel {
peer_id: peer_id.to_owned(),
conn_id,
sender,
receiver,
};
lock.push(msg_channel);
(conn_id, receiver2)
}
}
}
pub fn get_rx_cliprdr_server(conn_id: i32) -> Arc<TokioMutex<UnboundedReceiver<ClipboardFile>>> {
let mut lock = VEC_MSG_CHANNEL.write().unwrap();
match lock.iter().find(|x| x.conn_id == conn_id) {
Some(msg_channel) => msg_channel.receiver.clone(),
None => {
let (sender, receiver) = unbounded_channel();
let receiver = Arc::new(TokioMutex::new(receiver));
let receiver2 = receiver.clone();
let msg_channel = MsgChannel {
peer_id: "".to_string(),
conn_id,
sender,
receiver,
};
lock.push(msg_channel);
receiver2
}
}
}
pub fn remove_channel_by_conn_id(conn_id: i32) {
let mut lock = VEC_MSG_CHANNEL.write().unwrap();
if let Some(index) = lock.iter().position(|x| x.conn_id == conn_id) {
lock.remove(index);
}
}
#[cfg(any(target_os = "windows", feature = "unix-file-copy-paste"))]
#[inline]
pub fn send_data(conn_id: i32, data: ClipboardFile) -> Result<(), CliprdrError> {
#[cfg(target_os = "windows")]
return send_data_to_channel(conn_id, data);
#[cfg(not(target_os = "windows"))]
if conn_id == 0 {
let _ = send_data_to_all(data);
Ok(())
} else {
send_data_to_channel(conn_id, data)
}
}
#[inline]
#[cfg(any(target_os = "windows", feature = "unix-file-copy-paste"))]
fn send_data_to_channel(conn_id: i32, data: ClipboardFile) -> Result<(), CliprdrError> {
if let Some(msg_channel) = VEC_MSG_CHANNEL
.read()
.unwrap()
.iter()
.find(|x| x.conn_id == conn_id)
{
msg_channel
.sender
.send(data)
.map_err(|e| CliprdrError::CommonError {
description: e.to_string(),
})
} else {
Err(CliprdrError::InvalidRequest {
description: "conn_id not found".to_string(),
})
}
}
#[inline]
#[cfg(target_os = "windows")]
pub fn send_data_exclude(conn_id: i32, data: ClipboardFile) {
// Need more tests to see if it's necessary to handle the error.
for msg_channel in VEC_MSG_CHANNEL.read().unwrap().iter() {
if msg_channel.conn_id != conn_id {
allow_err!(msg_channel.sender.send(data.clone()));
}
}
}
#[inline]
#[cfg(feature = "unix-file-copy-paste")]
fn send_data_to_all(data: ClipboardFile) {
// Need more tests to see if it's necessary to handle the error.
for msg_channel in VEC_MSG_CHANNEL.read().unwrap().iter() {
allow_err!(msg_channel.sender.send(data.clone()));
}
}
#[cfg(test)]
mod tests {
// #[test]
// fn test_cliprdr_run() {
// super::cliprdr_run();
// }
}
+26
View File
@@ -0,0 +1,26 @@
#[cfg(target_os = "windows")]
pub mod windows;
#[cfg(target_os = "windows")]
pub fn create_cliprdr_context(
enable_files: bool,
enable_others: bool,
response_wait_timeout_secs: u32,
) -> crate::ResultType<Box<dyn crate::CliprdrServiceContext>> {
let boxed =
windows::create_cliprdr_context(enable_files, enable_others, response_wait_timeout_secs)?
as Box<_>;
Ok(boxed)
}
#[cfg(feature = "unix-file-copy-paste")]
pub mod unix;
#[cfg(all(feature = "unix-file-copy-paste", target_os = "macos"))]
pub fn create_cliprdr_context(
_enable_files: bool,
_enable_others: bool,
_response_wait_timeout_secs: u32,
) -> crate::ResultType<Box<dyn crate::CliprdrServiceContext>> {
let boxed = unix::macos::pasteboard_context::create_pasteboard_context()? as Box<_>;
Ok(boxed)
}
@@ -0,0 +1,188 @@
use super::{FLAGS_FD_ATTRIBUTES, FLAGS_FD_LAST_WRITE, FLAGS_FD_UNIX_MODE, LDAP_EPOCH_DELTA};
use crate::CliprdrError;
use hbb_common::{
bytes::{Buf, Bytes},
log,
};
use serde_derive::{Deserialize, Serialize};
use std::{
path::PathBuf,
time::{Duration, SystemTime},
};
use utf16string::WStr;
#[cfg(target_os = "linux")]
pub type Inode = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FileType {
File,
Directory,
// todo: support symlink
Symlink,
}
/// read only permission
pub const PERM_READ: u16 = 0o444;
/// read and write permission
pub const PERM_RW: u16 = 0o644;
/// only self can read and readonly
pub const PERM_SELF_RO: u16 = 0o400;
/// rwx
pub const PERM_RWX: u16 = 0o755;
#[allow(dead_code)]
/// max length of file name
pub const MAX_NAME_LEN: usize = 255;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FileDescription {
pub conn_id: i32,
pub name: PathBuf,
pub kind: FileType,
pub atime: SystemTime,
pub last_modified: SystemTime,
pub last_metadata_changed: SystemTime,
pub creation_time: SystemTime,
pub size: u64,
pub perm: u16,
}
impl FileDescription {
fn parse_file_descriptor(
bytes: &mut Bytes,
conn_id: i32,
) -> Result<FileDescription, CliprdrError> {
let flags = bytes.get_u32_le();
// skip reserved 32 bytes
bytes.advance(32);
let attributes = bytes.get_u32_le();
// in original specification, this is 16 bytes reserved
// we use the last 4 bytes to store the file mode
// skip reserved 12 bytes
bytes.advance(12);
let perm = bytes.get_u32_le() as u16;
// last write time from 1601-01-01 00:00:00, in 100ns
let last_write_time = bytes.get_u64_le();
// file size
let file_size_high = bytes.get_u32_le();
let file_size_low = bytes.get_u32_le();
// utf16 file name, double \0 terminated, in 520 bytes block
// read with another pointer, and advance the main pointer
let block = bytes.clone();
bytes.advance(520);
let block = &block[..520];
let wstr = WStr::from_utf16le(block).map_err(|e| {
log::error!("cannot convert file descriptor path: {:?}", e);
CliprdrError::ConversionFailure
})?;
let from_unix = flags & FLAGS_FD_UNIX_MODE != 0;
let valid_attributes = flags & FLAGS_FD_ATTRIBUTES != 0;
if !valid_attributes {
return Err(CliprdrError::InvalidRequest {
description: "file description must have valid attributes".to_string(),
});
}
// todo: check normal, hidden, system, readonly, archive...
let directory = attributes & 0x10 != 0;
let normal = attributes == 0x80;
let hidden = attributes & 0x02 != 0;
let readonly = attributes & 0x01 != 0;
let perm = if from_unix {
// as is
perm
// cannot set as is...
} else if normal {
PERM_RWX
} else if readonly {
PERM_READ
} else if hidden {
PERM_SELF_RO
} else if directory {
PERM_RWX
} else {
PERM_RW
};
let kind = if directory {
FileType::Directory
} else {
FileType::File
};
// to-do: use `let valid_size = flags & FLAGS_FD_SIZE != 0;`
// We use `true` to for compatibility with Windows.
// let valid_size = flags & FLAGS_FD_SIZE != 0;
let valid_size = true;
let size = if valid_size {
((file_size_high as u64) << 32) + file_size_low as u64
} else {
0
};
let valid_write_time = flags & FLAGS_FD_LAST_WRITE != 0;
let last_modified = if valid_write_time && last_write_time >= LDAP_EPOCH_DELTA {
let last_write_time = (last_write_time - LDAP_EPOCH_DELTA) * 100;
let last_write_time = Duration::from_nanos(last_write_time);
SystemTime::UNIX_EPOCH + last_write_time
} else {
SystemTime::UNIX_EPOCH
};
let name = wstr.to_utf8().replace('\\', "/");
let name = PathBuf::from(name.trim_end_matches('\0'));
let desc = FileDescription {
conn_id,
name,
kind,
atime: last_modified,
last_modified,
last_metadata_changed: last_modified,
creation_time: last_modified,
size,
perm,
};
Ok(desc)
}
/// parse file descriptions from a format data response PDU
/// which containing a CSPTR_FILEDESCRIPTORW indicated format data
pub fn parse_file_descriptors(
file_descriptor_pdu: Vec<u8>,
conn_id: i32,
) -> Result<Vec<Self>, CliprdrError> {
let mut data = Bytes::from(file_descriptor_pdu);
if data.remaining() < 4 {
return Err(CliprdrError::InvalidRequest {
description: "file descriptor request with infficient length".to_string(),
});
}
let count = data.get_u32_le() as usize;
if data.remaining() == 0 && count == 0 {
return Ok(Vec::new());
}
if data.remaining() != 592 * count {
return Err(CliprdrError::InvalidRequest {
description: "file descriptor request with invalid length".to_string(),
});
}
let mut files = Vec::with_capacity(count);
for _ in 0..count {
let desc = Self::parse_file_descriptor(&mut data, conn_id)?;
files.push(desc);
}
Ok(files)
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,225 @@
mod cs;
use super::filetype::FileDescription;
use crate::{ClipboardFile, CliprdrError};
use cs::FuseServer;
use fuser::MountOption;
use hbb_common::{config::APP_NAME, log};
use parking_lot::Mutex;
use std::{
path::PathBuf,
sync::{mpsc::Sender, Arc},
time::Duration,
};
lazy_static::lazy_static! {
static ref FUSE_MOUNT_POINT_CLIENT: Arc<String> = {
let mnt_path = format!("/tmp/{}/{}", APP_NAME.read().unwrap(), "cliprdr-client");
// No need to run `canonicalize()` here.
Arc::new(mnt_path)
};
static ref FUSE_MOUNT_POINT_SERVER: Arc<String> = {
let mnt_path = format!("/tmp/{}/{}", APP_NAME.read().unwrap(), "cliprdr-server");
// No need to run `canonicalize()` here.
Arc::new(mnt_path)
};
static ref FUSE_CONTEXT_CLIENT: Arc<Mutex<Option<FuseContext>>> = Arc::new(Mutex::new(None));
static ref FUSE_CONTEXT_SERVER: Arc<Mutex<Option<FuseContext>>> = Arc::new(Mutex::new(None));
}
static FUSE_TIMEOUT: Duration = Duration::from_secs(3);
pub fn get_exclude_paths(is_client: bool) -> Arc<String> {
if is_client {
FUSE_MOUNT_POINT_CLIENT.clone()
} else {
FUSE_MOUNT_POINT_SERVER.clone()
}
}
pub fn is_fuse_context_inited(is_client: bool) -> bool {
if is_client {
FUSE_CONTEXT_CLIENT.lock().is_some()
} else {
FUSE_CONTEXT_SERVER.lock().is_some()
}
}
pub fn init_fuse_context(is_client: bool) -> Result<(), CliprdrError> {
let mut fuse_context_lock = if is_client {
FUSE_CONTEXT_CLIENT.lock()
} else {
FUSE_CONTEXT_SERVER.lock()
};
if fuse_context_lock.is_some() {
return Ok(());
}
let mount_point = if is_client {
FUSE_MOUNT_POINT_CLIENT.clone()
} else {
FUSE_MOUNT_POINT_SERVER.clone()
};
let mount_point = std::path::PathBuf::from(&*mount_point);
let (server, tx) = FuseServer::new(FUSE_TIMEOUT);
let server = Arc::new(Mutex::new(server));
prepare_fuse_mount_point(&mount_point);
let mnt_opts = [
MountOption::FSName("rustdesk-cliprdr-fs".to_string()),
MountOption::NoAtime,
MountOption::RO,
];
log::info!("mounting clipboard FUSE to {}", mount_point.display());
// to-do: ignore the error if the mount point is already mounted
// Because the sciter version uses separate processes as the controlling side.
let session = fuser::spawn_mount2(
FuseServer::client(server.clone()),
mount_point.clone(),
&mnt_opts,
)
.map_err(|e| {
log::error!("failed to mount cliprdr fuse: {:?}", e);
CliprdrError::CliprdrInit
})?;
let session = Mutex::new(Some(session));
let ctx = FuseContext {
server,
tx,
mount_point,
session,
conn_id: 0,
};
*fuse_context_lock = Some(ctx);
Ok(())
}
pub fn uninit_fuse_context(is_client: bool) {
uninit_fuse_context_(is_client)
}
pub fn format_data_response_to_urls(
is_client: bool,
format_data: Vec<u8>,
conn_id: i32,
) -> Result<Vec<String>, CliprdrError> {
let mut ctx = if is_client {
FUSE_CONTEXT_CLIENT.lock()
} else {
FUSE_CONTEXT_SERVER.lock()
};
ctx.as_mut()
.ok_or(CliprdrError::CliprdrInit)?
.format_data_response_to_urls(format_data, conn_id)
}
pub fn handle_file_content_response(
is_client: bool,
clip: ClipboardFile,
) -> Result<(), CliprdrError> {
// we don't know its corresponding request, no resend can be performed
let ctx = if is_client {
FUSE_CONTEXT_CLIENT.lock()
} else {
FUSE_CONTEXT_SERVER.lock()
};
ctx.as_ref()
.ok_or(CliprdrError::CliprdrInit)?
.tx
.send(clip)
.map_err(|e| {
log::error!("failed to send file contents response to fuse: {:?}", e);
CliprdrError::ClipboardInternalError
})?;
Ok(())
}
pub fn empty_local_files(is_client: bool, conn_id: i32) -> bool {
let ctx = if is_client {
FUSE_CONTEXT_CLIENT.lock()
} else {
FUSE_CONTEXT_SERVER.lock()
};
ctx.as_ref()
.map(|c| c.empty_local_files(conn_id))
.unwrap_or(false)
}
struct FuseContext {
server: Arc<Mutex<FuseServer>>,
tx: Sender<ClipboardFile>,
mount_point: PathBuf,
// stores fuse background session handle
session: Mutex<Option<fuser::BackgroundSession>>,
// Indicates the connection ID of that set the clipboard content
conn_id: i32,
}
// this function must be called after the main IPC is up
fn prepare_fuse_mount_point(mount_point: &PathBuf) {
use std::{
fs::{self, Permissions},
os::unix::prelude::PermissionsExt,
};
fs::create_dir(mount_point).ok();
fs::set_permissions(mount_point, Permissions::from_mode(0o777)).ok();
if let Err(e) = std::process::Command::new("umount")
.arg(mount_point)
.status()
{
log::warn!("umount {:?} may fail: {:?}", mount_point, e);
}
}
fn uninit_fuse_context_(is_client: bool) {
if is_client {
let _ = FUSE_CONTEXT_CLIENT.lock().take();
} else {
let _ = FUSE_CONTEXT_SERVER.lock().take();
}
}
impl Drop for FuseContext {
fn drop(&mut self) {
self.session.lock().take().map(|s| s.join());
log::info!("unmounting clipboard FUSE from {}", self.mount_point.display());
}
}
impl FuseContext {
pub fn empty_local_files(&self, conn_id: i32) -> bool {
if conn_id != 0 && self.conn_id != conn_id {
return false;
}
let mut fuse_guard = self.server.lock();
let _ = fuse_guard.load_file_list(vec![]);
true
}
pub fn format_data_response_to_urls(
&mut self,
format_data: Vec<u8>,
conn_id: i32,
) -> Result<Vec<String>, CliprdrError> {
let files = FileDescription::parse_file_descriptors(format_data, conn_id)?;
let paths = {
let mut fuse_guard = self.server.lock();
fuse_guard.load_file_list(files)?;
self.conn_id = conn_id;
fuse_guard.list_root()
};
let prefix = self.mount_point.clone();
Ok(paths
.into_iter()
.map(|p| prefix.join(p).to_string_lossy().to_string())
.collect())
}
}
@@ -0,0 +1,387 @@
use super::{BLOCK_SIZE, LDAP_EPOCH_DELTA};
use crate::{
platform::unix::{
FLAGS_FD_ATTRIBUTES, FLAGS_FD_LAST_WRITE, FLAGS_FD_PROGRESSUI, FLAGS_FD_SIZE,
FLAGS_FD_UNIX_MODE,
},
CliprdrError,
};
use hbb_common::{
bytes::{BufMut, BytesMut},
log,
};
use std::{
collections::HashSet,
fs::File,
io::{BufRead, BufReader, Read, Seek},
os::unix::prelude::PermissionsExt,
path::{Path, PathBuf},
sync::atomic::{AtomicU64, Ordering},
time::SystemTime,
};
use utf16string::WString;
#[derive(Debug)]
pub(super) struct LocalFile {
pub relative_root: PathBuf,
pub path: PathBuf,
pub handle: Option<BufReader<File>>,
pub offset: AtomicU64,
pub name: String,
pub size: u64,
pub last_write_time: SystemTime,
pub is_dir: bool,
pub perm: u32,
pub read_only: bool,
pub hidden: bool,
pub system: bool,
pub archive: bool,
pub normal: bool,
}
impl LocalFile {
pub fn try_open(relative_root: &Path, path: &Path) -> Result<Self, CliprdrError> {
let mt = std::fs::metadata(path).map_err(|e| CliprdrError::FileError {
path: path.to_string_lossy().to_string(),
err: e,
})?;
let size = mt.len() as u64;
let is_dir = mt.is_dir();
let read_only = mt.permissions().readonly();
let system = false;
let hidden = path.to_string_lossy().starts_with('.');
let archive = false;
let normal = !(is_dir || read_only || system || hidden || archive);
let last_write_time = mt.modified().unwrap_or(SystemTime::UNIX_EPOCH);
let perm = mt.permissions().mode();
let name = path
.display()
.to_string()
.trim_start_matches('/')
.replace('/', "\\");
// NOTE: open files lazily
let handle = None;
let offset = AtomicU64::new(0);
Ok(Self {
name,
relative_root: relative_root.to_path_buf(),
path: path.to_path_buf(),
handle,
offset,
size,
last_write_time,
is_dir,
read_only,
system,
hidden,
perm,
archive,
normal,
})
}
pub fn as_bin(&self) -> Vec<u8> {
let mut buf = BytesMut::with_capacity(592);
let read_only_flag = if self.read_only { 0x1 } else { 0 };
let hidden_flag = if self.hidden { 0x2 } else { 0 };
let system_flag = if self.system { 0x4 } else { 0 };
let directory_flag = if self.is_dir { 0x10 } else { 0 };
let archive_flag = if self.archive { 0x20 } else { 0 };
let normal_flag = if self.normal { 0x80 } else { 0 };
let file_attributes: u32 = read_only_flag
| hidden_flag
| system_flag
| directory_flag
| archive_flag
| normal_flag;
let win32_time = self
.last_write_time
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64
/ 100
+ LDAP_EPOCH_DELTA;
let size_high = (self.size >> 32) as u32;
let size_low = (self.size & (u32::MAX as u64)) as u32;
let path = self
.path
.strip_prefix(&self.relative_root)
.unwrap_or(&self.path)
.to_string_lossy()
.into_owned();
let wstr: WString<utf16string::LE> = WString::from(&path);
let name = wstr.as_bytes();
log::trace!(
"put file to list: name_len {}, name {}",
name.len(),
&self.name
);
let flags = FLAGS_FD_SIZE
| FLAGS_FD_LAST_WRITE
| FLAGS_FD_ATTRIBUTES
| FLAGS_FD_PROGRESSUI
| FLAGS_FD_UNIX_MODE;
// flags, 4 bytes
buf.put_u32_le(flags);
// 32 bytes reserved
buf.put(&[0u8; 32][..]);
// file attributes, 4 bytes
buf.put_u32_le(file_attributes);
// NOTE: this is not used in windows
// in the specification, this is 16 bytes reserved
// lets use the last 4 bytes to store the file mode
//
// 12 bytes reserved
buf.put(&[0u8; 12][..]);
// file permissions, 4 bytes
buf.put_u32_le(self.perm);
// last write time, 8 bytes
buf.put_u64_le(win32_time);
// file size (high)
buf.put_u32_le(size_high);
// file size (low)
buf.put_u32_le(size_low);
// put name and padding to 520 bytes
let name_len = name.len();
buf.put(name);
buf.put(&vec![0u8; 520 - name_len][..]);
buf.to_vec()
}
#[inline]
pub fn load_handle(&mut self) -> Result<(), CliprdrError> {
if !self.is_dir && self.handle.is_none() {
let handle = std::fs::File::open(&self.path).map_err(|e| CliprdrError::FileError {
path: self.path.to_string_lossy().to_string(),
err: e,
})?;
let mut reader = BufReader::with_capacity(BLOCK_SIZE as usize * 2, handle);
reader.fill_buf().map_err(|e| CliprdrError::FileError {
path: self.path.to_string_lossy().to_string(),
err: e,
})?;
self.handle = Some(reader);
};
Ok(())
}
pub fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> Result<(), CliprdrError> {
self.load_handle()?;
let Some(handle) = self.handle.as_mut() else {
return Err(CliprdrError::FileError {
path: self.path.to_string_lossy().to_string(),
err: std::io::Error::new(std::io::ErrorKind::NotFound, "file handle not found"),
});
};
if offset != self.offset.load(Ordering::Relaxed) {
handle
.seek(std::io::SeekFrom::Start(offset))
.map_err(|e| CliprdrError::FileError {
path: self.path.to_string_lossy().to_string(),
err: e,
})?;
}
handle
.read_exact(buf)
.map_err(|e| CliprdrError::FileError {
path: self.path.to_string_lossy().to_string(),
err: e,
})?;
let new_offset = offset + (buf.len() as u64);
self.offset.store(new_offset, Ordering::Relaxed);
// gc file handle
if new_offset >= self.size {
self.offset.store(0, Ordering::Relaxed);
self.handle = None;
}
Ok(())
}
}
pub(super) fn construct_file_list(paths: &[PathBuf]) -> Result<Vec<LocalFile>, CliprdrError> {
fn constr_file_lst(
relative_root: &Path,
path: &Path,
file_list: &mut Vec<LocalFile>,
visited: &mut HashSet<PathBuf>,
) -> Result<(), CliprdrError> {
// prevent fs loop
if visited.contains(path) {
return Ok(());
}
visited.insert(path.to_path_buf());
let local_file = LocalFile::try_open(relative_root, path)?;
file_list.push(local_file);
let mt = std::fs::metadata(path).map_err(|e| CliprdrError::FileError {
path: path.to_string_lossy().to_string(),
err: e,
})?;
if mt.is_dir() {
let dir = std::fs::read_dir(path).map_err(|e| CliprdrError::FileError {
path: path.to_string_lossy().to_string(),
err: e,
})?;
for entry in dir {
let entry = entry.map_err(|e| CliprdrError::FileError {
path: path.to_string_lossy().to_string(),
err: e,
})?;
let path = entry.path();
constr_file_lst(relative_root, &path, file_list, visited)?;
}
}
Ok(())
}
let mut file_list = Vec::new();
let mut visited = HashSet::new();
let relative_root = paths
.first()
.ok_or(CliprdrError::InvalidRequest {
description: "empty file list".to_string(),
})?
.parent()
.ok_or(CliprdrError::InvalidRequest {
description: "empty parent".to_string(),
})?
.to_path_buf();
for path in paths {
constr_file_lst(&relative_root, path, &mut file_list, &mut visited)?;
}
Ok(file_list)
}
#[cfg(test)]
mod file_list_test {
use std::{path::PathBuf, sync::atomic::AtomicU64};
use hbb_common::bytes::{BufMut, BytesMut};
use crate::{platform::unix::filetype::FileDescription, CliprdrError};
use super::LocalFile;
#[inline]
fn generate_tree(prefix: &str) -> Vec<LocalFile> {
// generate a tree of local files, no handles
// - /
// |- a.txt
// |- b
// |- c.txt
#[inline]
fn generate_file(path: &str, name: &str, is_dir: bool) -> LocalFile {
LocalFile {
relative_root: PathBuf::from("."),
path: PathBuf::from(path),
handle: None,
name: name.to_string(),
size: 0,
offset: AtomicU64::new(0),
last_write_time: std::time::SystemTime::UNIX_EPOCH,
read_only: false,
is_dir,
perm: 0o754,
hidden: false,
system: false,
archive: false,
normal: false,
}
}
let p = prefix;
let (r_path, a_path, b_path, c_path) = if !prefix.is_empty() {
(
p.to_string(),
format!("{}/a.txt", p),
format!("{}/b", p),
format!("{}/b/c.txt", p),
)
} else {
(
".".to_string(),
"a.txt".to_string(),
"b".to_string(),
"b/c.txt".to_string(),
)
};
let root = generate_file(&r_path, ".", true);
let a = generate_file(&a_path, "a.txt", false);
let b = generate_file(&b_path, "b", true);
let c = generate_file(&c_path, "c.txt", false);
vec![root, a, b, c]
}
fn as_bin_parse_test(prefix: &str) -> Result<(), CliprdrError> {
let tree = generate_tree(prefix);
let mut pdu = BytesMut::with_capacity(4 + 592 * tree.len());
pdu.put_u32_le(tree.len() as u32);
for file in tree {
pdu.put(file.as_bin().as_slice());
}
let parsed = FileDescription::parse_file_descriptors(pdu.to_vec(), 0)?;
assert_eq!(parsed.len(), 4);
if !prefix.is_empty() {
assert_eq!(parsed[0].name.to_str().unwrap(), format!("{}", prefix));
assert_eq!(
parsed[1].name.to_str().unwrap(),
format!("{}/a.txt", prefix)
);
assert_eq!(parsed[2].name.to_str().unwrap(), format!("{}/b", prefix));
assert_eq!(
parsed[3].name.to_str().unwrap(),
format!("{}/b/c.txt", prefix)
);
} else {
assert_eq!(parsed[0].name.to_str().unwrap(), ".");
assert_eq!(parsed[1].name.to_str().unwrap(), "a.txt");
assert_eq!(parsed[2].name.to_str().unwrap(), "b");
assert_eq!(parsed[3].name.to_str().unwrap(), "b/c.txt");
}
assert!(parsed[0].perm & 0o777 == 0o754);
assert!(parsed[1].perm & 0o777 == 0o754);
assert!(parsed[2].perm & 0o777 == 0o754);
assert!(parsed[3].perm & 0o777 == 0o754);
Ok(())
}
#[test]
fn test_parse_file_descriptors() -> Result<(), CliprdrError> {
as_bin_parse_test("")?;
as_bin_parse_test("/")?;
as_bin_parse_test("test")?;
as_bin_parse_test("/test")?;
Ok(())
}
}
@@ -0,0 +1,77 @@
use super::pasteboard_context::{PasteObserverInfo, TEMP_FILE_PREFIX};
use objc2::{
declare_class, msg_send_id, mutability,
rc::Id,
runtime::{NSObject, NSObjectProtocol},
ClassType, DeclaredClass,
};
use objc2_app_kit::{
NSPasteboard, NSPasteboardItem, NSPasteboardItemDataProvider, NSPasteboardType,
NSPasteboardTypeFileURL,
};
use objc2_foundation::NSString;
use std::{io::Result, sync::mpsc::Sender};
pub(super) struct Ivars {
task_info: PasteObserverInfo,
tx: Sender<Result<PasteObserverInfo>>,
}
declare_class!(
pub(super) struct PasteboardFileUrlProvider;
unsafe impl ClassType for PasteboardFileUrlProvider {
type Super = NSObject;
type Mutability = mutability::InteriorMutable;
const NAME: &'static str = "PasteboardFileUrlProvider";
}
impl DeclaredClass for PasteboardFileUrlProvider {
type Ivars = Ivars;
}
unsafe impl NSObjectProtocol for PasteboardFileUrlProvider {}
unsafe impl NSPasteboardItemDataProvider for PasteboardFileUrlProvider {
#[method(pasteboard:item:provideDataForType:)]
#[allow(non_snake_case)]
unsafe fn pasteboard_item_provideDataForType(
&self,
_pasteboard: Option<&NSPasteboard>,
item: &NSPasteboardItem,
r#type: &NSPasteboardType,
) {
if r#type == NSPasteboardTypeFileURL {
let path = format!("/tmp/{}{}", TEMP_FILE_PREFIX, uuid::Uuid::new_v4().to_string());
match std::fs::File::create(&path) {
Ok(_) => {
let url = format!("file:///{}", &path);
item.setString_forType(&NSString::from_str(&url), &NSPasteboardTypeFileURL);
let mut task_info = self.ivars().task_info.clone();
task_info.source_path = path;
self.ivars().tx.send(Ok(task_info)).ok();
}
Err(e) => {
self.ivars().tx.send(Err(e)).ok();
}
}
}
}
// #[method(pasteboardFinishedWithDataProvider:)]
// unsafe fn pasteboardFinishedWithDataProvider(&self, pasteboard: &NSPasteboard) {
// }
}
unsafe impl PasteboardFileUrlProvider {}
);
pub(super) fn create_pasteboard_file_url_provider(
task_info: PasteObserverInfo,
tx: Sender<Result<PasteObserverInfo>>,
) -> Id<PasteboardFileUrlProvider> {
let provider = PasteboardFileUrlProvider::alloc();
let provider = provider.set_ivars(Ivars { task_info, tx });
let provider: Id<PasteboardFileUrlProvider> = unsafe { msg_send_id![super(provider), init] };
provider
}
@@ -0,0 +1,14 @@
mod item_data_provider;
mod paste_observer;
mod paste_task;
pub mod pasteboard_context;
pub fn should_handle_msg(msg: &crate::ClipboardFile) -> bool {
matches!(
msg,
crate::ClipboardFile::FormatList { .. }
| crate::ClipboardFile::FormatDataResponse { .. }
| crate::ClipboardFile::FileContentsResponse { .. }
| crate::ClipboardFile::TryEmpty
)
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

@@ -0,0 +1,179 @@
use super::pasteboard_context::PasteObserverInfo;
use fsevent::{self, StreamFlags};
use hbb_common::{bail, log, ResultType};
use std::{
sync::{
mpsc::{channel, Receiver, RecvTimeoutError, Sender},
Arc, Mutex,
},
thread,
time::Duration,
};
enum FseventControl {
Start,
Stop,
Exit,
}
struct FseventThreadInfo {
tx: Sender<FseventControl>,
handle: thread::JoinHandle<()>,
}
pub struct PasteObserver {
exit: Arc<Mutex<bool>>,
observer_info: Arc<Mutex<Option<PasteObserverInfo>>>,
tx_handle_fsevent_thread: Option<FseventThreadInfo>,
handle_observer_thread: Option<thread::JoinHandle<()>>,
}
impl Drop for PasteObserver {
fn drop(&mut self) {
*self.exit.lock().unwrap() = true;
if let Some(handle_observer_thread) = self.handle_observer_thread.take() {
handle_observer_thread.join().ok();
}
if let Some(tx_handle_fsevent_thread) = self.tx_handle_fsevent_thread.take() {
tx_handle_fsevent_thread.tx.send(FseventControl::Exit).ok();
tx_handle_fsevent_thread.handle.join().ok();
}
}
}
impl PasteObserver {
const OBSERVE_TIMEOUT: Duration = Duration::from_secs(30);
pub fn new() -> Self {
Self {
exit: Arc::new(Mutex::new(false)),
observer_info: Default::default(),
tx_handle_fsevent_thread: None,
handle_observer_thread: None,
}
}
pub fn init(&mut self, cb_pasted: fn(&PasteObserverInfo) -> ()) -> ResultType<()> {
let Some(home_dir) = dirs::home_dir() else {
bail!("No home dir is set, do not observe.");
};
let (tx_observer, rx_observer) = channel::<fsevent::Event>();
let handle_observer = Self::init_thread_observer(
self.exit.clone(),
self.observer_info.clone(),
rx_observer,
cb_pasted,
);
self.handle_observer_thread = Some(handle_observer);
let (tx_control, rx_control) = channel::<FseventControl>();
let handle_fsevent = Self::init_thread_fsevent(
home_dir.to_string_lossy().to_string(),
tx_observer,
rx_control,
);
self.tx_handle_fsevent_thread = Some(FseventThreadInfo {
tx: tx_control,
handle: handle_fsevent,
});
Ok(())
}
#[inline]
fn get_file_from_path(path: &String) -> String {
let last_slash = path.rfind('/').or_else(|| path.rfind('\\'));
match last_slash {
Some(index) => path[index + 1..].to_string(),
None => path.clone(),
}
}
fn init_thread_observer(
exit: Arc<Mutex<bool>>,
observer_info: Arc<Mutex<Option<PasteObserverInfo>>>,
rx_observer: Receiver<fsevent::Event>,
cb_pasted: fn(&PasteObserverInfo) -> (),
) -> thread::JoinHandle<()> {
thread::spawn(move || loop {
match rx_observer.recv_timeout(Duration::from_millis(300)) {
Ok(event) => {
if (event.flag & StreamFlags::ITEM_CREATED) != StreamFlags::NONE
&& (event.flag & StreamFlags::ITEM_REMOVED) == StreamFlags::NONE
&& (event.flag & StreamFlags::IS_FILE) != StreamFlags::NONE
{
let source_file = observer_info
.lock()
.unwrap()
.as_ref()
.map(|x| Self::get_file_from_path(&x.source_path));
if let Some(source_file) = source_file {
let file = Self::get_file_from_path(&event.path);
if source_file == file {
if let Some(observer_info) = observer_info.lock().unwrap().as_mut()
{
observer_info.target_path = event.path.clone();
cb_pasted(observer_info);
}
}
}
}
}
Err(_) => {
if *(exit.lock().unwrap()) {
break;
}
}
}
})
}
fn new_fsevent(home_dir: String, tx_observer: Sender<fsevent::Event>) -> fsevent::FsEvent {
let mut evt = fsevent::FsEvent::new(vec![home_dir.to_string()]);
evt.observe_async(tx_observer).ok();
evt
}
fn init_thread_fsevent(
home_dir: String,
tx_observer: Sender<fsevent::Event>,
rx_control: Receiver<FseventControl>,
) -> thread::JoinHandle<()> {
log::debug!("fsevent observe dir: {}", &home_dir);
thread::spawn(move || {
let mut fsevent = None;
loop {
match rx_control.recv_timeout(Self::OBSERVE_TIMEOUT) {
Ok(FseventControl::Start) => {
if fsevent.is_none() {
fsevent =
Some(Self::new_fsevent(home_dir.clone(), tx_observer.clone()));
}
}
Ok(FseventControl::Stop) | Err(RecvTimeoutError::Timeout) => {
let _ = fsevent.as_mut().map(|e| e.shutdown_observe());
fsevent = None;
}
Ok(FseventControl::Exit) | Err(RecvTimeoutError::Disconnected) => {
break;
}
}
}
log::info!("fsevent thread exit");
let _ = fsevent.as_mut().map(|e| e.shutdown_observe());
})
}
pub fn start(&mut self, observer_info: PasteObserverInfo) {
if let Some(tx_handle_fsevent_thread) = self.tx_handle_fsevent_thread.as_ref() {
self.observer_info.lock().unwrap().replace(observer_info);
tx_handle_fsevent_thread.tx.send(FseventControl::Start).ok();
}
}
pub fn stop(&mut self) {
if let Some(tx_handle_fsevent_thread) = &self.tx_handle_fsevent_thread {
self.observer_info = Default::default();
tx_handle_fsevent_thread.tx.send(FseventControl::Stop).ok();
}
}
}
@@ -0,0 +1,639 @@
use crate::{
platform::unix::{FileDescription, FileType, BLOCK_SIZE},
send_data, ClipboardFile, CliprdrError, ProgressPercent,
};
use hbb_common::{allow_err, log, tokio::time::Instant};
use std::{
cmp::min,
fs::{File, FileTimes},
io::{BufWriter, Write},
os::macos::fs::FileTimesExt,
path::{Path, PathBuf},
sync::{
mpsc::{Receiver, RecvTimeoutError},
Arc, Mutex,
},
thread,
time::{Duration, SystemTime},
};
const RECV_RETRY_TIMES: usize = 3;
const DOWNLOAD_EXTENSION: &str = "rddownload";
const RECEIVE_WAIT_TIMEOUT: Duration = Duration::from_millis(5_000);
// https://stackoverflow.com/a/15112784/1926020
// "1984-01-24 08:00:00 +0000"
const TIMESTAMP_FOR_FILE_PROGRESS_COMPLETED: u64 = 443779200;
const ATTR_PROGRESS_FRACTION_COMPLETED: &str = "com.apple.progress.fractionCompleted";
pub struct FileContentsResponse {
pub conn_id: i32,
pub msg_flags: i32,
pub stream_id: i32,
pub requested_data: Vec<u8>,
}
#[derive(Debug)]
struct PasteTaskProgress {
// Use list index to identify the file
// `list_index` is also used as the stream id
list_index: i32,
offset: u64,
total_size: u64,
current_size: u64,
last_sent_time: Instant,
download_file_index: i32,
download_file_size: u64,
download_file_path: String,
download_file_current_size: u64,
file_handle: Option<BufWriter<File>>,
error: Option<CliprdrError>,
is_canceled: bool,
}
struct PasteTaskHandle {
progress: PasteTaskProgress,
target_dir: PathBuf,
files: Vec<FileDescription>,
}
pub struct PasteTask {
exit: Arc<Mutex<bool>>,
handle: Arc<Mutex<Option<PasteTaskHandle>>>,
handle_worker: Option<thread::JoinHandle<()>>,
}
impl Drop for PasteTask {
fn drop(&mut self) {
*self.exit.lock().unwrap() = true;
if let Some(handle_worker) = self.handle_worker.take() {
handle_worker.join().ok();
}
}
}
impl PasteTask {
const INVALID_FILE_INDEX: i32 = -1;
pub fn new(rx_file_contents: Receiver<FileContentsResponse>) -> Self {
let exit = Arc::new(Mutex::new(false));
let handle = Arc::new(Mutex::new(None));
let handle_worker =
Self::init_worker_thread(exit.clone(), handle.clone(), rx_file_contents);
Self {
handle,
exit,
handle_worker: Some(handle_worker),
}
}
pub fn start(&mut self, target_dir: PathBuf, files: Vec<FileDescription>) {
let mut task_lock = self.handle.lock().unwrap();
if task_lock
.as_ref()
.map(|x| !x.is_finished())
.unwrap_or(false)
{
log::error!("Previous paste task is not finished, ignore new request.");
return;
}
let total_size = files.iter().map(|f| f.size).sum();
let mut task_handle = PasteTaskHandle {
progress: PasteTaskProgress {
list_index: -1,
offset: 0,
total_size,
current_size: 0,
last_sent_time: Instant::now(),
download_file_index: Self::INVALID_FILE_INDEX,
download_file_size: 0,
download_file_path: "".to_owned(),
download_file_current_size: 0,
file_handle: None,
error: None,
is_canceled: false,
},
target_dir,
files,
};
task_handle.update_next(0).ok();
if task_handle.is_finished() {
task_handle.on_finished();
} else {
if let Err(e) = task_handle.send_file_contents_request() {
log::error!("Failed to send file contents request, error: {}", &e);
task_handle.on_error(e);
}
}
*task_lock = Some(task_handle);
}
pub fn cancel(&self) {
let mut task_handle = self.handle.lock().unwrap();
if let Some(task_handle) = task_handle.as_mut() {
task_handle.progress.is_canceled = true;
task_handle.on_cancelled();
}
}
fn init_worker_thread(
exit: Arc<Mutex<bool>>,
handle: Arc<Mutex<Option<PasteTaskHandle>>>,
rx_file_contents: Receiver<FileContentsResponse>,
) -> thread::JoinHandle<()> {
thread::spawn(move || {
let mut retry_count = 0;
loop {
if *exit.lock().unwrap() {
break;
}
match rx_file_contents.recv_timeout(Duration::from_millis(300)) {
Ok(file_contents) => {
let mut task_lock = handle.lock().unwrap();
let Some(task_handle) = task_lock.as_mut() else {
continue;
};
if task_handle.is_finished() {
continue;
}
if file_contents.stream_id != task_handle.progress.list_index {
// ignore invalid stream id
continue;
} else if file_contents.msg_flags != 0x01 {
retry_count += 1;
if retry_count > RECV_RETRY_TIMES {
task_handle.progress.error = Some(CliprdrError::InvalidRequest {
description: format!(
"Failed to read file contents, stream id: {}, msg_flags: {}",
file_contents.stream_id,
file_contents.msg_flags
),
});
}
} else {
let resp_list_index = file_contents.stream_id;
let Some(file) = &task_handle.files.get(resp_list_index as usize)
else {
// unreachable
// Because `task_handle.progress.list_index >= task_handle.files.len()` should always be false
log::warn!(
"Invalid response list index: {}, file length: {}",
resp_list_index,
task_handle.files.len()
);
continue;
};
if file.conn_id != file_contents.conn_id {
// unreachable
// We still add log here to make sure we can see the error message when it happens.
log::error!(
"Invalid response conn id: {}, expected: {}",
file_contents.conn_id,
file.conn_id
);
continue;
}
if let Err(e) = task_handle.handle_file_contents_response(file_contents)
{
log::error!("Failed to handle file contents response: {}", &e);
task_handle.on_error(e);
}
}
if !task_handle.is_finished() {
if let Err(e) = task_handle.send_file_contents_request() {
log::error!("Failed to send file contents request: {}", &e);
task_handle.on_error(e);
}
} else {
retry_count = 0;
task_handle.on_finished();
}
}
Err(RecvTimeoutError::Timeout) => {
let mut task_lock = handle.lock().unwrap();
if let Some(task_handle) = task_lock.as_mut() {
if task_handle.check_receive_timemout() {
retry_count = 0;
task_handle.on_finished();
}
}
}
Err(RecvTimeoutError::Disconnected) => {
break;
}
}
}
})
}
pub fn is_finished(&self) -> bool {
self.handle
.lock()
.unwrap()
.as_ref()
.map(|handle| handle.is_finished())
.unwrap_or(true)
}
pub fn progress_percent(&self) -> Option<ProgressPercent> {
self.handle
.lock()
.unwrap()
.as_ref()
.map(|handle| handle.progress_percent())
}
}
impl PasteTaskHandle {
fn update_next(&mut self, size: u64) -> Result<(), CliprdrError> {
if self.is_finished() {
return Ok(());
}
self.progress.current_size += size;
let is_start = self.progress.list_index == -1;
if is_start || (self.progress.offset + size) >= self.progress.download_file_size {
if !is_start {
self.on_done();
}
for i in (self.progress.list_index + 1)..self.files.len() as i32 {
let Some(file_desc) = self.files.get(i as usize) else {
return Err(CliprdrError::InvalidRequest {
description: format!("Invalid file index: {}", i),
});
};
match file_desc.kind {
FileType::File => {
if file_desc.size == 0 {
if let Some(new_file_path) =
Self::get_new_filename(&self.target_dir, file_desc)
{
if let Ok(f) = std::fs::File::create(&new_file_path) {
f.set_len(0).ok();
Self::set_file_metadata(&f, file_desc);
}
};
} else {
self.progress.list_index = i;
self.progress.offset = 0;
self.open_new_writer()?;
break;
}
}
FileType::Directory => {
let path = self.target_dir.join(&file_desc.name);
if !path.exists() {
std::fs::create_dir_all(path).ok();
}
}
FileType::Symlink => {
// to-do: handle symlink
}
}
}
} else {
self.progress.offset += size;
self.progress.download_file_current_size += size;
self.update_progress_completed(None);
}
if self.progress.file_handle.is_none() {
self.progress.list_index = self.files.len() as i32;
self.progress.offset = 0;
self.progress.download_file_size = 0;
self.progress.download_file_current_size = 0;
}
Ok(())
}
fn start_progress_completed(&self) {
if let Some(file) = self.progress.file_handle.as_ref() {
let creation_time =
SystemTime::UNIX_EPOCH + Duration::from_secs(TIMESTAMP_FOR_FILE_PROGRESS_COMPLETED);
file.get_ref()
.set_times(FileTimes::new().set_created(creation_time))
.ok();
xattr::set(
&self.progress.download_file_path,
ATTR_PROGRESS_FRACTION_COMPLETED,
"0.0".as_bytes(),
)
.ok();
}
}
fn update_progress_completed(&mut self, fraction_completed: Option<f64>) {
let fraction_completed = fraction_completed.unwrap_or_else(|| {
let current_size = self.progress.download_file_current_size as f64;
let total_size = self.progress.download_file_size as f64;
if total_size > 0.0 {
current_size / total_size
} else {
1.0
}
});
xattr::set(
&self.progress.download_file_path,
ATTR_PROGRESS_FRACTION_COMPLETED,
&fraction_completed.to_string().as_bytes(),
)
.ok();
}
#[inline]
fn remove_progress_completed(path: &str) {
if !path.is_empty() {
xattr::remove(path, ATTR_PROGRESS_FRACTION_COMPLETED).ok();
}
}
fn open_new_writer(&mut self) -> Result<(), CliprdrError> {
let Some(file) = &self.files.get(self.progress.list_index as usize) else {
return Err(CliprdrError::InvalidRequest {
description: format!(
"Invalid file index: {}, file count: {}",
self.progress.list_index,
self.files.len()
),
});
};
let original_file_path = self
.target_dir
.join(&file.name)
.to_string_lossy()
.to_string();
let Some(download_file_path) = Self::get_first_filename(
format!("{}.{}", original_file_path, DOWNLOAD_EXTENSION),
file.kind,
) else {
return Err(CliprdrError::CommonError {
description: format!("Failed to get download file path: {}", original_file_path),
});
};
let Some(download_path_parent) = Path::new(&download_file_path).parent() else {
return Err(CliprdrError::CommonError {
description: format!(
"Failed to get parent of the download file path: {}",
original_file_path
),
});
};
if !download_path_parent.exists() {
if let Err(e) = std::fs::create_dir_all(download_path_parent) {
return Err(CliprdrError::FileError {
path: download_path_parent.to_string_lossy().to_string(),
err: e,
});
}
}
match std::fs::File::create(&download_file_path) {
Ok(handle) => {
let writer = BufWriter::with_capacity(BLOCK_SIZE as usize * 2, handle);
self.progress.download_file_index = self.progress.list_index;
self.progress.download_file_size = file.size;
self.progress.download_file_path = download_file_path;
self.progress.download_file_current_size = 0;
self.progress.file_handle = Some(writer);
self.start_progress_completed();
}
Err(e) => {
self.progress.error = Some(CliprdrError::FileError {
path: download_file_path,
err: e,
});
}
};
Ok(())
}
fn get_first_filename(path: String, r#type: FileType) -> Option<String> {
let p = Path::new(&path);
if !p.exists() {
return Some(path);
} else {
for i in 1..9999999 {
let new_path = match r#type {
FileType::File => {
if let Some(ext) = p.extension() {
let new_name = format!(
"{}-{}.{}",
p.file_stem().unwrap_or_default().to_string_lossy(),
i,
ext.to_string_lossy()
);
p.with_file_name(new_name).to_string_lossy().to_string()
} else {
format!("{} ({})", path, i)
}
}
FileType::Directory => format!("{} ({})", path, i),
FileType::Symlink => {
// to-do: handle symlink
return None;
}
};
if !Path::new(&new_path).exists() {
return Some(new_path);
}
}
}
// unreachable
None
}
fn progress_percent(&self) -> ProgressPercent {
let percent = self.progress.current_size as f64 / self.progress.total_size as f64;
ProgressPercent {
percent,
is_canceled: self.progress.is_canceled,
is_failed: self.progress.error.is_some(),
}
}
fn is_finished(&self) -> bool {
self.progress.is_canceled
|| self.progress.error.is_some()
|| self.progress.list_index >= self.files.len() as i32
}
fn check_receive_timemout(&mut self) -> bool {
if !self.is_finished() {
if self.progress.last_sent_time.elapsed() > RECEIVE_WAIT_TIMEOUT {
self.progress.error = Some(CliprdrError::InvalidRequest {
description: "Failed to read file contents".to_string(),
});
return true;
}
}
false
}
fn on_finished(&mut self) {
if self.progress.error.is_some() {
self.on_cancelled();
} else {
self.on_done();
}
if self.progress.current_size != self.progress.total_size {
self.progress.error = Some(CliprdrError::InvalidRequest {
description: "Failed to download all files".to_string(),
});
}
}
fn on_error(&mut self, error: CliprdrError) {
self.progress.error = Some(error);
self.on_cancelled();
}
fn on_cancelled(&mut self) {
self.progress.file_handle = None;
std::fs::remove_file(&self.progress.download_file_path).ok();
}
fn on_done(&mut self) {
self.update_progress_completed(Some(1.0));
Self::remove_progress_completed(&self.progress.download_file_path);
let Some(file) = self.progress.file_handle.as_mut() else {
return;
};
if self.progress.download_file_index == PasteTask::INVALID_FILE_INDEX {
return;
}
if let Err(e) = file.flush() {
log::error!("Failed to flush file: {:?}", e);
}
self.progress.file_handle = None;
let Some(file_desc) = self.files.get(self.progress.download_file_index as usize) else {
// unreachable
log::error!(
"Failed to get file description: {}",
self.progress.download_file_index
);
return;
};
let Some(rename_to_path) = Self::get_new_filename(&self.target_dir, file_desc) else {
return;
};
match std::fs::rename(&self.progress.download_file_path, &rename_to_path) {
Ok(_) => Self::set_file_metadata2(&rename_to_path, file_desc),
Err(e) => {
log::error!("Failed to rename file: {:?}", e);
}
}
self.progress.download_file_path = "".to_owned();
self.progress.download_file_index = PasteTask::INVALID_FILE_INDEX;
}
fn get_new_filename(target_dir: &PathBuf, file_desc: &FileDescription) -> Option<String> {
let mut rename_to_path = target_dir
.join(&file_desc.name)
.to_string_lossy()
.to_string();
if Path::new(&rename_to_path).exists() {
let Some(new_path) = Self::get_first_filename(rename_to_path.clone(), file_desc.kind)
else {
log::error!("Failed to get new file name: {}", &rename_to_path);
return None;
};
rename_to_path = new_path;
}
Some(rename_to_path)
}
#[inline]
fn set_file_metadata(f: &File, file_desc: &FileDescription) {
let times = FileTimes::new()
.set_accessed(file_desc.atime)
.set_modified(file_desc.last_modified)
.set_created(file_desc.creation_time);
f.set_times(times).ok();
}
#[inline]
fn set_file_metadata2(path: &str, file_desc: &FileDescription) {
let times = FileTimes::new()
.set_accessed(file_desc.atime)
.set_modified(file_desc.last_modified)
.set_created(file_desc.creation_time);
File::options()
.write(true)
.open(path)
.map(|f| f.set_times(times))
.ok();
}
fn send_file_contents_request(&mut self) -> Result<(), CliprdrError> {
if self.is_finished() {
return Ok(());
}
let stream_id = self.progress.list_index;
let list_index = self.progress.list_index;
let Some(file) = &self.files.get(list_index as usize) else {
// unreachable
return Err(CliprdrError::InvalidRequest {
description: format!("Invalid file index: {}", list_index),
});
};
let cb_requested = min(BLOCK_SIZE as u64, file.size - self.progress.offset);
let conn_id = file.conn_id;
let (n_position_high, n_position_low) = (
(self.progress.offset >> 32) as i32,
(self.progress.offset & (u32::MAX as u64)) as i32,
);
let request = ClipboardFile::FileContentsRequest {
stream_id,
list_index,
dw_flags: 2,
n_position_low,
n_position_high,
cb_requested: cb_requested as _,
have_clip_data_id: false,
clip_data_id: 0,
};
allow_err!(send_data(conn_id, request));
self.progress.last_sent_time = Instant::now();
Ok(())
}
fn handle_file_contents_response(
&mut self,
file_contents: FileContentsResponse,
) -> Result<(), CliprdrError> {
if let Some(file) = self.progress.file_handle.as_mut() {
let data = file_contents.requested_data.as_slice();
let mut write_len = 0;
while write_len < data.len() {
match file.write(&data[write_len..]) {
Ok(len) => {
write_len += len;
}
Err(e) => {
return Err(CliprdrError::FileError {
path: self.progress.download_file_path.clone(),
err: e,
});
}
}
}
self.update_next(write_len as _)?;
} else {
return Err(CliprdrError::FileError {
path: self.progress.download_file_path.clone(),
err: std::io::Error::new(std::io::ErrorKind::NotFound, "file handle is not opened"),
});
}
Ok(())
}
}
@@ -0,0 +1,460 @@
use super::{
item_data_provider::create_pasteboard_file_url_provider,
paste_observer::PasteObserver,
paste_task::{FileContentsResponse, PasteTask},
};
use crate::{
platform::unix::{
filetype::FileDescription, FILECONTENTS_FORMAT_NAME, FILEDESCRIPTORW_FORMAT_NAME,
},
send_data, ClipboardFile, CliprdrError, CliprdrServiceContext, ProgressPercent,
};
use hbb_common::{allow_err, bail, log, ResultType};
use objc2::{msg_send_id, rc::autoreleasepool, rc::Id, runtime::ProtocolObject, ClassType};
use objc2_app_kit::{NSPasteboard, NSPasteboardTypeFileURL};
use objc2_foundation::{NSArray, NSString};
use std::{
io,
path::Path,
sync::{
mpsc::{channel, Receiver, RecvTimeoutError, Sender},
Arc, Mutex,
},
thread,
time::Duration,
};
lazy_static::lazy_static! {
static ref PASTE_OBSERVER_INFO: Arc<Mutex<Option<PasteObserverInfo>>> = Default::default();
}
pub const TEMP_FILE_PREFIX: &str = ".rustdesk_";
#[derive(Default, Debug, Clone, PartialEq)]
pub(super) struct PasteObserverInfo {
pub file_descriptor_id: i32,
pub conn_id: i32,
pub source_path: String,
pub target_path: String,
}
impl PasteObserverInfo {
fn exit_msg() -> Self {
Self::default()
}
}
struct ContextInfo {
tx: Sender<io::Result<PasteObserverInfo>>,
handle: thread::JoinHandle<()>,
}
pub struct PasteboardContext {
pasteboard: Id<NSPasteboard>,
observer: Arc<Mutex<PasteObserver>>,
tx_handle: Option<ContextInfo>,
tx_remove_file: Option<Sender<String>>,
remove_file_handle: Option<thread::JoinHandle<()>>,
tx_paste_task: Sender<FileContentsResponse>,
paste_task: Arc<Mutex<PasteTask>>,
}
unsafe impl Send for PasteboardContext {}
unsafe impl Sync for PasteboardContext {}
impl Drop for PasteboardContext {
fn drop(&mut self) {
self.observer.lock().unwrap().stop();
if let Some(tx_handle) = self.tx_handle.take() {
if tx_handle.tx.send(Ok(PasteObserverInfo::exit_msg())).is_ok() {
tx_handle.handle.join().ok();
}
}
}
}
impl CliprdrServiceContext for PasteboardContext {
fn set_is_stopped(&mut self) -> Result<(), CliprdrError> {
Ok(())
}
fn empty_clipboard(&mut self, conn_id: i32) -> Result<bool, CliprdrError> {
Ok(self.empty_clipboard_(conn_id))
}
fn server_clip_file(&mut self, conn_id: i32, msg: ClipboardFile) -> Result<(), CliprdrError> {
self.server_clip_file_(conn_id, msg)
}
fn get_progress_percent(&self) -> Option<ProgressPercent> {
self.paste_task.lock().unwrap().progress_percent()
}
fn cancel(&mut self) {
self.paste_task.lock().unwrap().cancel();
}
}
impl PasteboardContext {
fn init(&mut self) {
let (tx_remove_file, rx_remove_file) = channel();
let handle_remove_file = Self::init_thread_remove_file(rx_remove_file);
self.tx_remove_file = Some(tx_remove_file.clone());
self.remove_file_handle = Some(handle_remove_file);
let (tx, rx) = channel();
let observer: Arc<Mutex<PasteObserver>> = self.observer.clone();
let handle = Self::init_thread_observer(tx_remove_file, rx, observer);
self.tx_handle = Some(ContextInfo { tx, handle });
}
fn init_thread_observer(
tx_remove_file: Sender<String>,
rx: Receiver<io::Result<PasteObserverInfo>>,
observer: Arc<Mutex<PasteObserver>>,
) -> thread::JoinHandle<()> {
let exit_msg = PasteObserverInfo::exit_msg();
thread::spawn(move || loop {
match rx.recv() {
Ok(Ok(task_info)) => {
if task_info == exit_msg {
log::debug!("pasteboard item data provider: exit");
break;
}
tx_remove_file.send(task_info.source_path.clone()).ok();
observer.lock().unwrap().start(task_info);
}
Ok(Err(e)) => {
log::error!("pasteboard item data provider, inner error: {e}");
}
Err(e) => {
log::error!("pasteboard item data provider, error: {e}");
break;
}
}
})
}
fn init_thread_remove_file(rx: Receiver<String>) -> thread::JoinHandle<()> {
thread::spawn(move || {
let mut cur_file: Option<String> = None;
loop {
match rx.recv_timeout(Duration::from_secs(30)) {
Ok(path) => {
if let Some(file) = cur_file.take() {
if !file.is_empty() {
std::fs::remove_file(&file).ok();
}
}
if !path.is_empty() {
cur_file = Some(path);
}
}
Err(e) => {
if let Some(file) = cur_file.take() {
if !file.is_empty() {
std::fs::remove_file(&file).ok();
}
}
if e == RecvTimeoutError::Disconnected {
break;
}
}
}
}
})
}
// Just removing the file can also make paste option in the context menu disappear.
fn empty_clipboard_(&mut self, _conn_id: i32) -> bool {
self.tx_remove_file
.as_ref()
.map(|tx| tx.send("".to_string()).ok());
true
}
fn temp_files_count() -> usize {
let mut count = 0;
if let Ok(entries) = std::fs::read_dir("/tmp") {
for entry in entries {
if let Ok(entry) = entry {
let path = entry.path();
if path.is_file() {
if let Some(file_name) = path.file_name() {
if let Some(file_name_str) = file_name.to_str() {
if file_name_str.starts_with(TEMP_FILE_PREFIX) {
count += 1;
}
}
}
}
}
}
}
count
}
fn server_clip_file_(&mut self, conn_id: i32, msg: ClipboardFile) -> Result<(), CliprdrError> {
match msg {
ClipboardFile::FormatList { format_list } => {
let temp_files = Self::temp_files_count();
if temp_files >= 3 {
// The temp files should be 0 or 1 in normal case.
// We should not continue to paste files if there are more than 3 temp files.
return Err(CliprdrError::CommonError {
description: format!(
"too many temp files, current: {}, limit: {}",
temp_files, 3
),
});
}
let task_lock = self.paste_task.lock().unwrap();
if !task_lock.is_finished() {
return Err(CliprdrError::CommonError {
description: "previous file paste task is not finished".to_string(),
});
}
self.handle_format_list(conn_id, format_list)?;
}
ClipboardFile::FormatDataResponse {
msg_flags,
format_data,
} => {
self.handle_format_data_response(conn_id, msg_flags, format_data)?;
}
ClipboardFile::FileContentsResponse {
msg_flags,
stream_id,
requested_data,
} => {
self.handle_file_contents_response(conn_id, msg_flags, stream_id, requested_data)?;
}
ClipboardFile::TryEmpty => self.handle_try_empty(conn_id),
_ => {}
}
Ok(())
}
fn handle_format_list(
&self,
conn_id: i32,
format_list: Vec<(i32, String)>,
) -> Result<(), CliprdrError> {
if let Some(tx_handle) = self.tx_handle.as_ref() {
if !format_list
.iter()
.find(|(_, name)| name == FILECONTENTS_FORMAT_NAME)
.map(|(id, _)| *id)
.is_some()
{
return Err(CliprdrError::CommonError {
description: "no file contents format found".to_string(),
});
};
let Some(file_descriptor_id) = format_list
.iter()
.find(|(_, name)| name == FILEDESCRIPTORW_FORMAT_NAME)
.map(|(id, _)| *id)
else {
return Err(CliprdrError::CommonError {
description: "no file descriptor format found".to_string(),
});
};
autoreleasepool(|_| self.set_clipboard_item(tx_handle, conn_id, file_descriptor_id))?;
} else {
return Err(CliprdrError::CommonError {
description: "pasteboard context is not inited".to_string(),
});
}
Ok(())
}
fn set_clipboard_item(
&self,
tx_handle: &ContextInfo,
conn_id: i32,
file_descriptor_id: i32,
) -> Result<(), CliprdrError> {
let tx = tx_handle.tx.clone();
let provider = create_pasteboard_file_url_provider(
PasteObserverInfo {
file_descriptor_id,
conn_id,
source_path: "".to_string(),
target_path: "".to_string(),
},
tx,
);
unsafe {
let types = NSArray::from_vec(vec![NSString::from_str(
&NSPasteboardTypeFileURL.to_string(),
)]);
let item = objc2_app_kit::NSPasteboardItem::new();
item.setDataProvider_forTypes(&ProtocolObject::from_id(provider), &types);
self.pasteboard.clearContents();
if !self
.pasteboard
.writeObjects(&Id::cast(NSArray::from_vec(vec![item])))
{
return Err(CliprdrError::CommonError {
description: "failed to write objects".to_string(),
});
}
}
Ok(())
}
fn handle_format_data_response(
&self,
conn_id: i32,
msg_flags: i32,
format_data: Vec<u8>,
) -> Result<(), CliprdrError> {
log::debug!("handle format data response, msg_flags: {msg_flags}");
if msg_flags != 0x1 {
// return failure message?
}
let mut task_lock = self.paste_task.lock().unwrap();
let target_dir = PASTE_OBSERVER_INFO
.lock()
.unwrap()
.as_ref()
.map(|task| task.target_path.clone());
// unreachable in normal case
let Some(target_dir) = target_dir.as_ref().map(|d| Path::new(d).parent()).flatten() else {
return Err(CliprdrError::CommonError {
description: "failed to get parent path".to_string(),
});
};
// unreachable in normal case
if !target_dir.exists() {
return Err(CliprdrError::CommonError {
description: "target path does not exist".to_string(),
});
}
let target_dir = target_dir.to_owned();
match FileDescription::parse_file_descriptors(format_data, conn_id) {
Ok(files) => {
task_lock.start(target_dir, files);
Ok(())
}
Err(e) => {
PASTE_OBSERVER_INFO
.lock()
.unwrap()
.replace(PasteObserverInfo::default());
Err(e)
}
}
}
fn handle_file_contents_response(
&self,
conn_id: i32,
msg_flags: i32,
stream_id: i32,
requested_data: Vec<u8>,
) -> Result<(), CliprdrError> {
log::debug!("handle file contents response");
self.tx_paste_task
.send(FileContentsResponse {
conn_id,
msg_flags,
stream_id,
requested_data,
})
.ok();
Ok(())
}
fn handle_try_empty(&mut self, conn_id: i32) {
log::debug!("empty_clipboard called");
let ret = self.empty_clipboard_(conn_id);
log::debug!(
"empty_clipboard called, conn_id {}, return {}",
conn_id,
ret
);
}
}
fn handle_paste_result(task_info: &PasteObserverInfo) {
log::info!(
"file {} is pasted to {}",
&task_info.source_path,
&task_info.target_path
);
if Path::new(&task_info.target_path).parent().is_none() {
log::error!(
"failed to get parent path of {}, no need to perform pasting",
&task_info.target_path
);
return;
}
PASTE_OBSERVER_INFO
.lock()
.unwrap()
.replace(task_info.clone());
// to-do: add a timeout to clear data in `PASTE_OBSERVER_INFO`.
std::fs::remove_file(&task_info.source_path).ok();
std::fs::remove_file(&task_info.target_path).ok();
let data = ClipboardFile::FormatDataRequest {
requested_format_id: task_info.file_descriptor_id,
};
allow_err!(send_data(task_info.conn_id as _, data));
}
#[inline]
pub fn create_pasteboard_context() -> ResultType<Box<PasteboardContext>> {
let pasteboard: Option<Id<NSPasteboard>> =
unsafe { msg_send_id![NSPasteboard::class(), generalPasteboard] };
let Some(pasteboard) = pasteboard else {
bail!("failed to get general pasteboard");
};
let mut observer = PasteObserver::new();
observer.init(handle_paste_result)?;
let (tx, rx) = channel();
let mut context = Box::new(PasteboardContext {
pasteboard,
observer: Arc::new(Mutex::new(observer)),
tx_handle: None,
tx_remove_file: None,
remove_file_handle: None,
tx_paste_task: tx,
paste_task: Arc::new(Mutex::new(PasteTask::new(rx))),
});
context.init();
Ok(context)
}
#[cfg(test)]
mod tests {
#[test]
fn test_temp_files_count() {
let mut c = super::PasteboardContext::temp_files_count();
let mut created_files = vec![];
for _ in 0..10 {
let path = format!(
"/tmp/{}{}",
super::TEMP_FILE_PREFIX,
uuid::Uuid::new_v4().to_string()
);
if std::fs::File::create(&path).is_ok() {
created_files.push(path);
c += 1;
}
}
assert_eq!(c, super::PasteboardContext::temp_files_count());
// Clean up the created files.
for file in created_files {
std::fs::remove_file(&file).ok();
}
}
}
+58
View File
@@ -0,0 +1,58 @@
use dashmap::DashMap;
use lazy_static::lazy_static;
mod filetype;
pub use filetype::{FileDescription, FileType};
/// use FUSE for file pasting on these platforms
#[cfg(target_os = "linux")]
pub mod fuse;
#[cfg(target_os = "macos")]
pub mod macos;
pub mod local_file;
pub mod serv_files;
/// has valid file attributes
pub const FLAGS_FD_ATTRIBUTES: u32 = 0x04;
/// has valid file size
pub const FLAGS_FD_SIZE: u32 = 0x40;
/// has valid last write time
pub const FLAGS_FD_LAST_WRITE: u32 = 0x20;
/// show progress
pub const FLAGS_FD_PROGRESSUI: u32 = 0x4000;
/// transferred from unix, contains file mode
/// P.S. this flag is not used in windows
pub const FLAGS_FD_UNIX_MODE: u32 = 0x08;
// not actual format id, just a placeholder
pub const FILEDESCRIPTOR_FORMAT_ID: i32 = 49334;
pub const FILEDESCRIPTORW_FORMAT_NAME: &str = "FileGroupDescriptorW";
// not actual format id, just a placeholder
pub const FILECONTENTS_FORMAT_ID: i32 = 49267;
pub const FILECONTENTS_FORMAT_NAME: &str = "FileContents";
/// block size for fuse, align to our asynchronic request size over FileContentsRequest.
pub(crate) const BLOCK_SIZE: u32 = 4 * 1024 * 1024;
// begin of epoch used by microsoft
// 1601-01-01 00:00:00 + LDAP_EPOCH_DELTA*(100 ns) = 1970-01-01 00:00:00
const LDAP_EPOCH_DELTA: u64 = 116444772610000000;
lazy_static! {
static ref REMOTE_FORMAT_MAP: DashMap<i32, String> = DashMap::from_iter(
[
(
FILEDESCRIPTOR_FORMAT_ID,
FILEDESCRIPTORW_FORMAT_NAME.to_string()
),
(FILECONTENTS_FORMAT_ID, FILECONTENTS_FORMAT_NAME.to_string())
]
.iter()
.cloned()
);
}
#[inline]
pub fn get_local_format(remote_id: i32) -> Option<String> {
REMOTE_FORMAT_MAP.get(&remote_id).map(|s| s.clone())
}
@@ -0,0 +1,271 @@
use super::local_file::LocalFile;
use crate::{platform::unix::local_file::construct_file_list, ClipboardFile, CliprdrError};
use hbb_common::{
bytes::{BufMut, BytesMut},
log,
};
use parking_lot::Mutex;
use std::{path::PathBuf, sync::Arc, usize};
lazy_static::lazy_static! {
// local files are cached, this value should not be changed when copying files
// Because `CliprdrFileContentsRequest` only contains the index of the file in the list.
// We need to keep the file list in the same order as the remote side.
// We may add a `FileId` field to `CliprdrFileContentsRequest` in the future.
static ref CLIP_FILES: Arc<Mutex<ClipFiles>> = Default::default();
}
#[derive(Debug)]
enum FileContentsRequest {
Size {
stream_id: i32,
file_idx: usize,
},
Range {
stream_id: i32,
file_idx: usize,
offset: u64,
length: u64,
},
}
#[derive(Default)]
struct ClipFiles {
files: Vec<String>,
file_list: Vec<LocalFile>,
first_file_index: usize,
files_pdu: Vec<u8>,
}
impl ClipFiles {
fn clear(&mut self) {
self.files.clear();
self.file_list.clear();
self.first_file_index = usize::MAX;
self.files_pdu.clear();
}
fn sync_files(&mut self, clipboard_files: &[String]) -> Result<(), CliprdrError> {
let clipboard_paths = clipboard_files
.iter()
.map(|s| PathBuf::from(s))
.collect::<Vec<_>>();
self.file_list = construct_file_list(&clipboard_paths)?;
self.first_file_index = self
.file_list
.iter()
.position(|f| !f.path.is_dir())
.unwrap_or(usize::MAX);
self.files = clipboard_files.to_vec();
Ok(())
}
fn build_file_list_pdu(&mut self) {
let mut data = BytesMut::with_capacity(4 + 592 * self.file_list.len());
data.put_u32_le(self.file_list.len() as u32);
for file in self.file_list.iter() {
data.put(file.as_bin().as_slice());
}
self.files_pdu = data.to_vec()
}
fn get_files_for_audit(&self, request: &FileContentsRequest) -> Option<ClipboardFile> {
if let FileContentsRequest::Range {
file_idx, offset, ..
} = request
{
if *file_idx == self.first_file_index && *offset == 0 {
let files: Vec<(String, u64)> = self
.file_list
.iter()
.filter_map(|f| {
if f.path.is_file() {
Some((f.path.to_string_lossy().to_string(), f.size))
} else {
None
}
})
.collect::<_>();
if files.is_empty() {
return None;
} else {
return Some(ClipboardFile::Files { files });
}
}
}
None
}
fn serve_file_contents(
&mut self,
conn_id: i32,
request: FileContentsRequest,
) -> Result<ClipboardFile, CliprdrError> {
let (file_idx, file_contents_resp) = match request {
FileContentsRequest::Size {
stream_id,
file_idx,
} => {
log::debug!("file contents (size) requested from conn: {}", conn_id);
let Some(file) = self.file_list.get(file_idx) else {
log::error!(
"invalid file index {} requested from conn: {}",
file_idx,
conn_id
);
return Err(CliprdrError::InvalidRequest {
description: format!(
"invalid file index {} requested from conn: {}",
file_idx, conn_id
),
});
};
log::debug!(
"conn {} requested file-{}: {}",
conn_id,
file_idx,
file.name
);
let size = file.size;
(
file_idx,
ClipboardFile::FileContentsResponse {
msg_flags: 0x1,
stream_id,
requested_data: size.to_le_bytes().to_vec(),
},
)
}
FileContentsRequest::Range {
stream_id,
file_idx,
offset,
length,
} => {
log::debug!(
"file contents (range from {} length {}) request from conn: {}",
offset,
length,
conn_id
);
let Some(file) = self.file_list.get_mut(file_idx) else {
log::error!(
"invalid file index {} requested from conn: {}",
file_idx,
conn_id
);
return Err(CliprdrError::InvalidRequest {
description: format!(
"invalid file index {} requested from conn: {}",
file_idx, conn_id
),
});
};
log::debug!(
"conn {} requested file-{}: {}",
conn_id,
file_idx,
file.name
);
if offset > file.size {
log::error!("invalid reading offset requested from conn: {}", conn_id);
return Err(CliprdrError::InvalidRequest {
description: format!(
"invalid reading offset requested from conn: {}",
conn_id
),
});
}
let read_size = if offset + length > file.size {
file.size - offset
} else {
length
};
let mut buf = vec![0u8; read_size as usize];
file.read_exact_at(&mut buf, offset)?;
(
file_idx,
ClipboardFile::FileContentsResponse {
msg_flags: 0x1,
stream_id,
requested_data: buf,
},
)
}
};
log::debug!("file contents sent to conn: {}", conn_id);
// hot reload next file
for next_file in self.file_list.iter_mut().skip(file_idx + 1) {
if !next_file.is_dir {
next_file.load_handle()?;
break;
}
}
Ok(file_contents_resp)
}
}
#[inline]
pub fn clear_files() {
CLIP_FILES.lock().clear();
}
pub fn read_file_contents(
conn_id: i32,
stream_id: i32,
list_index: i32,
dw_flags: i32,
n_position_low: i32,
n_position_high: i32,
cb_requested: i32,
) -> Vec<Result<ClipboardFile, CliprdrError>> {
let fcr = if dw_flags == 0x1 {
FileContentsRequest::Size {
stream_id,
file_idx: list_index as usize,
}
} else if dw_flags == 0x2 {
let offset = (n_position_high as u64) << 32 | n_position_low as u64;
let length = cb_requested as u64;
FileContentsRequest::Range {
stream_id,
file_idx: list_index as usize,
offset,
length,
}
} else {
return vec![Err(CliprdrError::InvalidRequest {
description: format!("got invalid FileContentsRequest, dw_flats: {dw_flags}"),
})];
};
let mut clip_files = CLIP_FILES.lock();
let mut res = vec![];
if let Some(files_res) = clip_files.get_files_for_audit(&fcr) {
res.push(Ok(files_res));
}
res.push(clip_files.serve_file_contents(conn_id, fcr));
res
}
pub fn sync_files(files: &[String]) -> Result<(), CliprdrError> {
let mut files_lock = CLIP_FILES.lock();
if files_lock.files == files {
return Ok(());
}
files_lock.sync_files(files)?;
Ok(files_lock.build_file_list_pdu())
}
pub fn get_file_list_pdu() -> Vec<u8> {
CLIP_FILES.lock().files_pdu.clone()
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1 @@
* text=auto
+14
View File
@@ -0,0 +1,14 @@
.DS_Store
# Generated by Cargo
# will have compiled files and executables
/target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock
Cargo.lock
# RustFmt files
**/*.rs.bk
# intellij
.idea
+15
View File
@@ -0,0 +1,15 @@
language: rust
rust:
- stable
- beta
- nightly
matrix:
allow_failures:
- rust: nightly
before_install:
- if [ "$TRAVIS_OS_NAME" == "linux" ]; then sudo apt-get -qq update; fi
- if [ "$TRAVIS_OS_NAME" == "linux" ]; then sudo apt-get install -y libxdo-dev; fi
os:
- linux
- osx
+13
View File
@@ -0,0 +1,13 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Debug",
"type": "gdb",
"request": "launch",
"target": "./target/debug/examples/keyboard",
"cwd": "${workspaceRoot}"
}
]
}
+44
View File
@@ -0,0 +1,44 @@
[package]
name = "enigo"
version = "0.0.14"
authors = ["Dustin Bensing <dustin.bensing@googlemail.com>"]
edition = "2018"
build = "build.rs"
description = "Enigo lets you control your mouse and keyboard in an abstract way on different operating systems (currently only Linux, macOS, Win Redox and *BSD planned)"
documentation = "https://docs.rs/enigo/"
homepage = "https://github.com/enigo-rs/enigo"
repository = "https://github.com/enigo-rs/enigo"
readme = "README.md"
keywords = ["input", "mouse", "testing", "keyboard", "automation"]
categories = ["development-tools::testing", "api-bindings", "hardware-support"]
license = "MIT"
[badges]
travis-ci = { repository = "enigo-rs/enigo" }
appveyor = { repository = "pythoneer/enigo-85xiy" }
[dependencies]
serde = { version = "1.0", optional = true }
serde_derive = { version = "1.0", optional = true }
log = "0.4"
rdev = { git = "https://github.com/rustdesk-org/rdev" }
tfc = { git = "https://github.com/rustdesk-org/The-Fat-Controller", branch = "history/rebase_upstream_20240722" }
hbb_common = { path = "../hbb_common" }
[features]
with_serde = ["serde", "serde_derive"]
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["winuser", "winbase"] }
[target.'cfg(target_os = "macos")'.dependencies]
core-graphics = "0.22"
objc = "0.2"
unicode-segmentation = "1.10"
[target.'cfg(target_os = "linux")'.dependencies]
libxdo-sys = "0.11"
[build-dependencies]
pkg-config = "0.3"
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2017 pythoneer
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+121
View File
@@ -0,0 +1,121 @@
# AppVeyor configuration template for Rust using rustup for Rust installation
# https://github.com/starkat99/appveyor-rust
## Operating System (VM environment) ##
# Rust needs at least Visual Studio 2013 AppVeyor OS for MSVC targets.
os: Visual Studio 2015
## Build Matrix ##
# This configuration will setup a build for each channel & target combination (12 windows
# combinations in all).
#
# There are 3 channels: stable, beta, and nightly.
#
# Alternatively, the full version may be specified for the channel to build using that specific
# version (e.g. channel: 1.5.0)
#
# The values for target are the set of windows Rust build targets. Each value is of the form
#
# ARCH-pc-windows-TOOLCHAIN
#
# Where ARCH is the target architecture, either x86_64 or i686, and TOOLCHAIN is the linker
# toolchain to use, either msvc or gnu. See https://www.rust-lang.org/downloads.html#win-foot for
# a description of the toolchain differences.
# See https://github.com/rust-lang-nursery/rustup.rs/#toolchain-specification for description of
# toolchains and host triples.
#
# Comment out channel/target combos you do not wish to build in CI.
#
# You may use the `cargoflags` and `RUSTFLAGS` variables to set additional flags for cargo commands
# and rustc, respectively. For instance, you can uncomment the cargoflags lines in the nightly
# channels to enable unstable features when building for nightly. Or you could add additional
# matrix entries to test different combinations of features.
environment:
matrix:
### MSVC Toolchains ###
# Stable 64-bit MSVC
- channel: stable
target: x86_64-pc-windows-msvc
# Stable 32-bit MSVC
- channel: stable
target: i686-pc-windows-msvc
# Beta 64-bit MSVC
- channel: beta
target: x86_64-pc-windows-msvc
# Beta 32-bit MSVC
- channel: beta
target: i686-pc-windows-msvc
# Nightly 64-bit MSVC
- channel: nightly
target: x86_64-pc-windows-msvc
#cargoflags: --features "unstable"
# Nightly 32-bit MSVC
- channel: nightly
target: i686-pc-windows-msvc
#cargoflags: --features "unstable"
### GNU Toolchains ###
# Stable 64-bit GNU
- channel: stable
target: x86_64-pc-windows-gnu
# Stable 32-bit GNU
- channel: stable
target: i686-pc-windows-gnu
# Beta 64-bit GNU
- channel: beta
target: x86_64-pc-windows-gnu
# Beta 32-bit GNU
- channel: beta
target: i686-pc-windows-gnu
# Nightly 64-bit GNU
- channel: nightly
target: x86_64-pc-windows-gnu
#cargoflags: --features "unstable"
# Nightly 32-bit GNU
- channel: nightly
target: i686-pc-windows-gnu
#cargoflags: --features "unstable"
### Allowed failures ###
# See AppVeyor documentation for specific details. In short, place any channel or targets you wish
# to allow build failures on (usually nightly at least is a wise choice). This will prevent a build
# or test failure in the matching channels/targets from failing the entire build.
matrix:
allow_failures:
- channel: nightly
# If you only care about stable channel build failures, uncomment the following line:
#- channel: beta
## Install Script ##
# This is the most important part of the AppVeyor configuration. This installs the version of Rust
# specified by the 'channel' and 'target' environment variables from the build matrix. This uses
# rustup to install Rust.
#
# For simple configurations, instead of using the build matrix, you can simply set the
# default-toolchain and default-host manually here.
install:
- appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe
- rustup-init -yv --default-toolchain %channel% --default-host %target%
- set PATH=%PATH%;%USERPROFILE%\.cargo\bin
- rustc -vV
- cargo -vV
## Build Script ##
# 'cargo test' takes care of building for us, so disable AppVeyor's build stage. This prevents
# the "directory does not contain a project or solution file" error.
build: false
# Uses 'cargo test' to run tests and build. Alternatively, the project may call compiled programs
#directly or perform other testing commands. Rust will automatically be placed in the PATH
# environment variable.
test_script:
- cargo test --verbose %cargoflags%
+61
View File
@@ -0,0 +1,61 @@
#[cfg(target_os = "windows")]
fn main() {}
#[cfg(target_os = "macos")]
fn main() {}
#[cfg(target_os = "linux")]
use pkg_config;
#[cfg(target_os = "linux")]
use std::env;
#[cfg(target_os = "linux")]
use std::fs::File;
#[cfg(target_os = "linux")]
use std::io::Write;
#[cfg(target_os = "linux")]
use std::path::Path;
#[cfg(target_os = "linux")]
fn main() {
let libraries = [
"xext",
"gl",
"xcursor",
"xxf86vm",
"xft",
"xinerama",
"xi",
"x11",
"xlib_xcb",
"xmu",
"xrandr",
"xtst",
"xrender",
"xscrnsaver",
"xt",
];
let mut config = String::new();
for lib in libraries.iter() {
let libdir = match pkg_config::get_variable(lib, "libdir") {
Ok(libdir) => format!("Some(\"{}\")", libdir),
Err(_) => "None".to_string(),
};
config.push_str(&format!(
"pub const {}: Option<&'static str> = {};\n",
lib, libdir
));
}
let config = format!("pub mod config {{ pub mod libdir {{\n{}}}\n}}", config);
let out_dir = env::var("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("config.rs");
let mut f = File::create(&dest_path).unwrap();
f.write_all(&config.into_bytes()).unwrap();
let target = env::var("TARGET").unwrap();
if target.contains("linux") {
println!("cargo:rustc-link-lib=dl");
} else if target.contains("freebsd") || target.contains("dragonfly") {
println!("cargo:rustc-link-lib=c");
}
}
+1
View File
@@ -0,0 +1 @@
wrap_comments = true
+184
View File
@@ -0,0 +1,184 @@
use crate::{Key, KeyboardControllable};
use std::error::Error;
use std::fmt;
/// An error that can occur when parsing DSL
#[derive(Debug, PartialEq, Eq)]
pub enum ParseError {
/// When a tag doesn't exist.
/// Example: {+TEST}{-TEST}
/// ^^^^ ^^^^
UnknownTag(String),
/// When a { is encountered inside a {TAG}.
/// Example: {+HELLO{WORLD}
/// ^
UnexpectedOpen,
/// When a { is never matched with a }.
/// Example: {+SHIFT}Hello{-SHIFT
/// ^
UnmatchedOpen,
/// Opposite of UnmatchedOpen.
/// Example: +SHIFT}Hello{-SHIFT}
/// ^
UnmatchedClose,
}
impl Error for ParseError {
fn description(&self) -> &str {
match *self {
ParseError::UnknownTag(_) => "Unknown tag",
ParseError::UnexpectedOpen => "Unescaped open bracket ({) found inside tag name",
ParseError::UnmatchedOpen => "Unmatched open bracket ({). No matching close (})",
ParseError::UnmatchedClose => "Unmatched close bracket (}). No previous open ({)",
}
}
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.to_string())
}
}
/// Evaluate the DSL. This tokenizes the input and presses the keys.
pub fn eval<K>(enigo: &mut K, input: &str) -> Result<(), ParseError>
where
K: KeyboardControllable,
{
for token in tokenize(input)? {
match token {
Token::Sequence(buffer) => {
for key in buffer.chars() {
enigo.key_click(Key::Layout(key));
}
}
Token::Unicode(buffer) => enigo.key_sequence(&buffer),
Token::KeyUp(key) => enigo.key_up(key),
Token::KeyDown(key) => enigo.key_down(key).unwrap_or(()),
}
}
Ok(())
}
#[derive(Debug, PartialEq, Eq)]
enum Token {
Sequence(String),
Unicode(String),
KeyUp(Key),
KeyDown(Key),
}
fn tokenize(input: &str) -> Result<Vec<Token>, ParseError> {
let mut unicode = false;
let mut tokens = Vec::new();
let mut buffer = String::new();
let mut iter = input.chars().peekable();
fn flush(tokens: &mut Vec<Token>, buffer: String, unicode: bool) {
if !buffer.is_empty() {
if unicode {
tokens.push(Token::Unicode(buffer));
} else {
tokens.push(Token::Sequence(buffer));
}
}
}
while let Some(c) = iter.next() {
if c == '{' {
match iter.next() {
Some('{') => buffer.push('{'),
Some(mut c) => {
flush(&mut tokens, buffer, unicode);
buffer = String::new();
let mut tag = String::new();
loop {
tag.push(c);
match iter.next() {
Some('{') => match iter.peek() {
Some(&'{') => {
iter.next();
c = '{'
}
_ => return Err(ParseError::UnexpectedOpen),
},
Some('}') => match iter.peek() {
Some(&'}') => {
iter.next();
c = '}'
}
_ => break,
},
Some(new) => c = new,
None => return Err(ParseError::UnmatchedOpen),
}
}
match &*tag {
"+UNICODE" => unicode = true,
"-UNICODE" => unicode = false,
"+SHIFT" => tokens.push(Token::KeyDown(Key::Shift)),
"-SHIFT" => tokens.push(Token::KeyUp(Key::Shift)),
"+CTRL" => tokens.push(Token::KeyDown(Key::Control)),
"-CTRL" => tokens.push(Token::KeyUp(Key::Control)),
"+META" => tokens.push(Token::KeyDown(Key::Meta)),
"-META" => tokens.push(Token::KeyUp(Key::Meta)),
"+ALT" => tokens.push(Token::KeyDown(Key::Alt)),
"-ALT" => tokens.push(Token::KeyUp(Key::Alt)),
_ => return Err(ParseError::UnknownTag(tag)),
}
}
None => return Err(ParseError::UnmatchedOpen),
}
} else if c == '}' {
match iter.next() {
Some('}') => buffer.push('}'),
_ => return Err(ParseError::UnmatchedClose),
}
} else {
buffer.push(c);
}
}
flush(&mut tokens, buffer, unicode);
Ok(tokens)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn success() {
assert_eq!(
tokenize("{{Hello World!}} {+CTRL}hi{-CTRL}"),
Ok(vec![
Token::Sequence("{Hello World!} ".into()),
Token::KeyDown(Key::Control),
Token::Sequence("hi".into()),
Token::KeyUp(Key::Control)
])
);
}
#[test]
fn unexpected_open() {
assert_eq!(tokenize("{hello{}world}"), Err(ParseError::UnexpectedOpen));
}
#[test]
fn unmatched_open() {
assert_eq!(
tokenize("{this is going to fail"),
Err(ParseError::UnmatchedOpen)
);
}
#[test]
fn unmatched_close() {
assert_eq!(
tokenize("{+CTRL}{{this}} is going to fail}"),
Err(ParseError::UnmatchedClose)
);
}
}
+552
View File
@@ -0,0 +1,552 @@
//! Enigo lets you simulate mouse and keyboard input-events as if they were
//! made by the actual hardware. The goal is to make it available on different
//! operating systems like Linux, macOS and Windows possibly many more but
//! [Redox](https://redox-os.org/) and *BSD are planned. Please see the
//! [Repo](https://github.com/enigo-rs/enigo) for the current status.
//!
//! I consider this library in an early alpha status, the API will change in
//! in the future. The keyboard handling is far from being very usable. I plan
//! to build a simple
//! [DSL](https://en.wikipedia.org/wiki/Domain-specific_language)
//! that will resemble something like:
//!
//! `"hello {+SHIFT}world{-SHIFT} and break line{ENTER}"`
//!
//! The current status is that you can just print
//! [unicode](http://unicode.org/)
//! characters like [emoji](http://getemoji.com/) without the `{+SHIFT}`
//! [DSL](https://en.wikipedia.org/wiki/Domain-specific_language)
//! or any other "special" key on the Linux, macOS and Windows operating system.
//!
//! Possible use cases could be for testing user interfaces on different
//! platforms,
//! building remote control applications or just automating tasks for user
//! interfaces unaccessible by a public API or scripting language.
//!
//! For the keyboard there are currently two modes you can use. The first mode
//! is represented by the [key_sequence]() function
//! its purpose is to simply write unicode characters. This is independent of
//! the keyboardlayout. Please note that
//! you're not be able to use modifier keys like Control
//! to influence the outcome. If you want to use modifier keys to e.g.
//! copy/paste
//! use the Layout variant. Please note that this is indeed layout dependent.
//! # Examples
//! ```no_run
//! use enigo::*;
//! let mut enigo = Enigo::new();
//! //paste
//! enigo.key_down(Key::Control);
//! enigo.key_click(Key::Layout('v'));
//! enigo.key_up(Key::Control);
//! ```
//!
//! ```no_run
//! use enigo::*;
//! let mut enigo = Enigo::new();
//! enigo.mouse_move_to(500, 200);
//! enigo.mouse_down(MouseButton::Left);
//! enigo.mouse_move_relative(100, 100);
//! enigo.mouse_up(MouseButton::Left);
//! enigo.key_sequence("hello world");
//! ```
#![deny(missing_docs)]
#[cfg(target_os = "macos")]
#[macro_use]
extern crate objc;
// TODO(dustin) use interior mutability not &mut self
#[cfg(target_os = "windows")]
mod win;
#[cfg(target_os = "windows")]
pub use win::Enigo;
#[cfg(target_os = "windows")]
pub use win::ENIGO_INPUT_EXTRA_VALUE;
#[cfg(target_os = "macos")]
mod macos;
#[cfg(target_os = "macos")]
pub use macos::Enigo;
#[cfg(target_os = "macos")]
pub use macos::ENIGO_INPUT_EXTRA_VALUE;
#[cfg(target_os = "linux")]
mod linux;
#[cfg(target_os = "linux")]
pub use crate::linux::Enigo;
/// DSL parser module
pub mod dsl;
#[cfg(feature = "with_serde")]
#[macro_use]
extern crate serde_derive;
#[cfg(feature = "with_serde")]
extern crate serde;
///
pub type ResultType = std::result::Result<(), Box<dyn std::error::Error>>;
#[cfg_attr(feature = "with_serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
/// MouseButton represents a mouse button,
/// and is used in for example
/// [mouse_click](trait.MouseControllable.html#tymethod.mouse_click).
/// WARNING: Types with the prefix Scroll
/// IS NOT intended to be used, and may not work on
/// all operating systems.
pub enum MouseButton {
/// Left mouse button
Left,
/// Middle mouse button
Middle,
/// Right mouse button
Right,
/// Back mouse button
Back,
/// Forward mouse button
Forward,
/// Scroll up button
ScrollUp,
/// Left right button
ScrollDown,
/// Left right button
ScrollLeft,
/// Left right button
ScrollRight,
}
/// Representing an interface and a set of mouse functions every
/// operating system implementation _should_ implement.
pub trait MouseControllable {
// https://stackoverflow.com/a/33687996
/// Offer the ability to confer concrete type.
fn as_any(&self) -> &dyn std::any::Any;
/// Offer the ability to confer concrete type.
fn as_mut_any(&mut self) -> &mut dyn std::any::Any;
/// Lets the mouse cursor move to the specified x and y coordinates.
///
/// The topleft corner of your monitor screen is x=0 y=0. Move
/// the cursor down the screen by increasing the y and to the right
/// by increasing x coordinate.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_move_to(500, 200);
/// ```
fn mouse_move_to(&mut self, x: i32, y: i32);
/// Lets the mouse cursor move the specified amount in the x and y
/// direction.
///
/// The amount specified in the x and y parameters are added to the
/// current location of the mouse cursor. A positive x values lets
/// the mouse cursor move an amount of `x` pixels to the right. A negative
/// value for `x` lets the mouse cursor go to the left. A positive value
/// of y
/// lets the mouse cursor go down, a negative one lets the mouse cursor go
/// up.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_move_relative(100, 100);
/// ```
fn mouse_move_relative(&mut self, x: i32, y: i32);
/// Push down one of the mouse buttons
///
/// Push down the mouse button specified by the parameter `button` of
/// type [MouseButton](enum.MouseButton.html)
/// and holds it until it is released by
/// [mouse_up](trait.MouseControllable.html#tymethod.mouse_up).
/// Calls to [mouse_move_to](trait.MouseControllable.html#tymethod.
/// mouse_move_to) or
/// [mouse_move_relative](trait.MouseControllable.html#tymethod.
/// mouse_move_relative)
/// will work like expected and will e.g. drag widgets or highlight text.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_down(MouseButton::Left);
/// ```
fn mouse_down(&mut self, button: MouseButton) -> ResultType;
/// Lift up a pushed down mouse button
///
/// Lift up a previously pushed down button (by invoking
/// [mouse_down](trait.MouseControllable.html#tymethod.mouse_down)).
/// If the button was not pushed down or consecutive calls without
/// invoking [mouse_down](trait.MouseControllable.html#tymethod.mouse_down)
/// will emit lift up events. It depends on the
/// operating system whats actually happening my guess is it will just
/// get ignored.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_up(MouseButton::Right);
/// ```
fn mouse_up(&mut self, button: MouseButton);
/// Click a mouse button
///
/// it's essentially just a consecutive invocation of
/// [mouse_down](trait.MouseControllable.html#tymethod.mouse_down) followed
/// by a [mouse_up](trait.MouseControllable.html#tymethod.mouse_up). Just
/// for
/// convenience.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_click(MouseButton::Right);
/// ```
fn mouse_click(&mut self, button: MouseButton);
/// Scroll the mouse (wheel) left or right
///
/// Positive numbers for length lets the mouse wheel scroll to the right
/// and negative ones to the left. The value that is specified translates
/// to `lines` defined by the operating system and is essentially one 15°
/// (click)rotation on the mouse wheel. How many lines it moves depends
/// on the current setting in the operating system.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_scroll_x(2);
/// ```
fn mouse_scroll_x(&mut self, length: i32);
/// Scroll the mouse (wheel) up or down
///
/// Positive numbers for length lets the mouse wheel scroll down
/// and negative ones up. The value that is specified translates
/// to `lines` defined by the operating system and is essentially one 15°
/// (click)rotation on the mouse wheel. How many lines it moves depends
/// on the current setting in the operating system.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.mouse_scroll_y(2);
/// ```
fn mouse_scroll_y(&mut self, length: i32);
}
/// A key on the keyboard.
/// For alphabetical keys, use Key::Layout for a system independent key.
/// If a key is missing, you can use the raw keycode with Key::Raw.
#[cfg_attr(feature = "with_serde", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Key {
/// alt key on Linux and Windows (option key on macOS)
Alt,
/// backspace key
Backspace,
/// caps lock key
CapsLock,
// #[deprecated(since = "0.0.12", note = "now renamed to Meta")]
/// command key on macOS (super key on Linux, windows key on Windows)
Command,
/// control key
Control,
/// delete key
Delete,
/// down arrow key
DownArrow,
/// end key
End,
/// escape key (esc)
Escape,
/// F1 key
F1,
/// F10 key
F10,
/// F11 key
F11,
/// F12 key
F12,
/// F2 key
F2,
/// F3 key
F3,
/// F4 key
F4,
/// F5 key
F5,
/// F6 key
F6,
/// F7 key
F7,
/// F8 key
F8,
/// F9 key
F9,
/// home key
Home,
/// left arrow key
LeftArrow,
/// meta key (also known as "windows", "super", and "command")
Meta,
/// option key on macOS (alt key on Linux and Windows)
Option, // deprecated, use Alt instead
/// page down key
PageDown,
/// page up key
PageUp,
/// return key
Return,
/// right arrow key
RightArrow,
/// shift key
Shift,
/// space key
Space,
// #[deprecated(since = "0.0.12", note = "now renamed to Meta")]
/// super key on linux (command key on macOS, windows key on Windows)
Super,
/// tab key (tabulator)
Tab,
/// up arrow key
UpArrow,
// #[deprecated(since = "0.0.12", note = "now renamed to Meta")]
/// windows key on Windows (super key on Linux, command key on macOS)
Windows,
///
Numpad0,
///
Numpad1,
///
Numpad2,
///
Numpad3,
///
Numpad4,
///
Numpad5,
///
Numpad6,
///
Numpad7,
///
Numpad8,
///
Numpad9,
///
Cancel,
///
Clear,
///
Pause,
///
Kana,
///
Hangul,
///
Junja,
///
Final,
///
Hanja,
///
Kanji,
///
Convert,
///
Select,
///
Print,
///
Execute,
///
Snapshot,
///
Insert,
///
Help,
///
Sleep,
///
Separator,
///
VolumeUp,
///
VolumeDown,
///
Mute,
///
Scroll,
/// scroll lock
NumLock,
///
RWin,
///
Apps,
///
Multiply,
///
Add,
///
Subtract,
///
Decimal,
///
Divide,
///
Equals,
///
NumpadEnter,
///
RightShift,
///
RightControl,
///
RightAlt,
///
/// Function, /// mac
/// keyboard layout dependent key
Layout(char),
/// raw keycode eg 0x38
Raw(u16),
}
/// Representing an interface and a set of keyboard functions every
/// operating system implementation _should_ implement.
pub trait KeyboardControllable {
// https://stackoverflow.com/a/33687996
/// Offer the ability to confer concrete type.
fn as_any(&self) -> &dyn std::any::Any;
/// Offer the ability to confer concrete type.
fn as_mut_any(&mut self) -> &mut dyn std::any::Any;
/// Types the string parsed with DSL.
///
/// Typing {+SHIFT}hello{-SHIFT} becomes HELLO.
/// TODO: Full documentation
fn key_sequence_parse(&mut self, sequence: &str)
where
Self: Sized,
{
if let Err(..) = self.key_sequence_parse_try(sequence) {
println!("Could not parse sequence");
}
}
/// Same as key_sequence_parse except returns any errors
fn key_sequence_parse_try(&mut self, sequence: &str) -> Result<(), dsl::ParseError>
where
Self: Sized,
{
dsl::eval(self, sequence)
}
/// Types the string
///
/// Emits keystrokes such that the given string is inputted.
///
/// You can use many unicode here like: ❤️. This works
/// regardless of the current keyboardlayout.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// enigo.key_sequence("hello world ❤️");
/// ```
fn key_sequence(&mut self, sequence: &str);
/// presses a given key down
fn key_down(&mut self, key: Key) -> ResultType;
/// release a given key formally pressed down by
/// [key_down](trait.KeyboardControllable.html#tymethod.key_down)
fn key_up(&mut self, key: Key);
/// Much like the
/// [key_down](trait.KeyboardControllable.html#tymethod.key_down) and
/// [key_up](trait.KeyboardControllable.html#tymethod.key_up)
/// function they're just invoked consecutively
fn key_click(&mut self, key: Key);
///
fn get_key_state(&mut self, key: Key) -> bool;
}
#[cfg(any(target_os = "android", target_os = "ios"))]
struct Enigo;
impl Enigo {
/// Constructs a new `Enigo` instance.
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut enigo = Enigo::new();
/// ```
pub fn new() -> Self {
#[cfg(any(target_os = "android", target_os = "ios"))]
return Enigo {};
#[cfg(not(any(target_os = "android", target_os = "ios")))]
Self::default()
}
}
use std::fmt;
impl fmt::Debug for Enigo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Enigo")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_key_state() {
let mut enigo = Enigo::new();
let keys = [Key::CapsLock, Key::NumLock];
for k in keys.iter() {
enigo.key_click(k.clone());
let a = enigo.get_key_state(k.clone());
enigo.key_click(k.clone());
let b = enigo.get_key_state(k.clone());
assert!(a != b);
}
let keys = [Key::Control, Key::Alt, Key::Shift];
for k in keys.iter() {
enigo.key_down(k.clone()).ok();
let a = enigo.get_key_state(k.clone());
enigo.key_up(k.clone());
let b = enigo.get_key_state(k.clone());
assert!(a != b);
}
}
}
+4
View File
@@ -0,0 +1,4 @@
mod nix_impl;
mod xdo;
pub use self::nix_impl::Enigo;
+392
View File
@@ -0,0 +1,392 @@
use super::xdo::EnigoXdo;
use crate::{Key, KeyboardControllable, MouseButton, MouseControllable, ResultType};
use std::io::Read;
use tfc::{traits::*, Context as TFC_Context, Key as TFC_Key};
pub type CustomKeyboard = Box<dyn KeyboardControllable + Send>;
pub type CustomMouce = Box<dyn MouseControllable + Send>;
/// The main struct for handling the event emitting
// #[derive(Default)]
pub struct Enigo {
xdo: EnigoXdo,
is_x11: bool,
tfc: Option<TFC_Context>,
custom_keyboard: Option<CustomKeyboard>,
custom_mouse: Option<CustomMouce>,
}
impl Enigo {
/// Get delay of xdo implementation.
pub fn delay(&self) -> u64 {
self.xdo.delay()
}
/// Set delay of xdo implementation.
pub fn set_delay(&mut self, delay: u64) {
self.xdo.set_delay(delay)
}
/// Set custom keyboard.
pub fn set_custom_keyboard(&mut self, custom_keyboard: CustomKeyboard) {
self.custom_keyboard = Some(custom_keyboard)
}
/// Set custom mouse.
pub fn set_custom_mouse(&mut self, custom_mouse: CustomMouce) {
self.custom_mouse = Some(custom_mouse)
}
/// Get custom keyboard.
pub fn get_custom_keyboard(&mut self) -> &mut Option<CustomKeyboard> {
&mut self.custom_keyboard
}
/// Get custom mouse.
pub fn get_custom_mouse(&mut self) -> &mut Option<CustomMouce> {
&mut self.custom_mouse
}
/// Clear remapped keycodes
pub fn tfc_clear_remapped(&mut self) {
if let Some(tfc) = &mut self.tfc {
tfc.recover_remapped_keycodes();
}
}
fn tfc_key_click(&mut self, key: Key) -> ResultType {
if let Some(tfc) = &mut self.tfc {
let res = match key {
Key::Layout(chr) => tfc.unicode_char(chr),
key => {
let tfc_key: TFC_Key = match convert_to_tfc_key(key) {
Some(key) => key,
None => {
return Err(format!("Failed to convert {:?} to TFC_Key", key).into());
}
};
tfc.key_click(tfc_key)
}
};
if res.is_err() {
Err(format!("Failed to click {:?} by tfc", key).into())
} else {
Ok(())
}
} else {
Err("Not Found TFC".into())
}
}
fn tfc_key_down_or_up(&mut self, key: Key, down: bool, up: bool) -> bool {
match &mut self.tfc {
None => false,
Some(tfc) => {
if let Key::Layout(chr) = key {
if down {
if let Err(_) = tfc.unicode_char_down(chr) {
return false;
}
}
if up {
if let Err(_) = tfc.unicode_char_up(chr) {
return false;
}
}
return true;
}
let key = match convert_to_tfc_key(key) {
Some(key) => key,
None => {
return false;
}
};
if down {
if let Err(_) = tfc.key_down(key) {
return false;
}
};
if up {
if let Err(_) = tfc.key_up(key) {
return false;
}
};
return true;
}
}
}
}
impl Default for Enigo {
fn default() -> Self {
let is_x11 = hbb_common::platform::linux::is_x11_or_headless();
Self {
is_x11,
tfc: if is_x11 {
match TFC_Context::new() {
Ok(ctx) => Some(ctx),
Err(..) => {
println!("kbd context error");
None
}
}
} else {
None
},
custom_keyboard: None,
custom_mouse: None,
xdo: EnigoXdo::default(),
}
}
}
impl MouseControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn mouse_move_to(&mut self, x: i32, y: i32) {
if self.is_x11 {
self.xdo.mouse_move_to(x, y);
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_move_to(x, y)
}
}
}
fn mouse_move_relative(&mut self, x: i32, y: i32) {
if self.is_x11 {
self.xdo.mouse_move_relative(x, y);
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_move_relative(x, y)
}
}
}
fn mouse_down(&mut self, button: MouseButton) -> crate::ResultType {
if self.is_x11 {
self.xdo.mouse_down(button)
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_down(button)
} else {
Ok(())
}
}
}
fn mouse_up(&mut self, button: MouseButton) {
if self.is_x11 {
self.xdo.mouse_up(button)
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_up(button)
}
}
}
fn mouse_click(&mut self, button: MouseButton) {
if self.is_x11 {
self.xdo.mouse_click(button)
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_click(button)
}
}
}
fn mouse_scroll_x(&mut self, length: i32) {
if self.is_x11 {
self.xdo.mouse_scroll_x(length)
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_scroll_x(length)
}
}
}
fn mouse_scroll_y(&mut self, length: i32) {
if self.is_x11 {
self.xdo.mouse_scroll_y(length)
} else {
if let Some(mouse) = &mut self.custom_mouse {
mouse.mouse_scroll_y(length)
}
}
}
}
fn get_led_state(key: Key) -> bool {
let led_file = match key {
// FIXME: the file may be /sys/class/leds/input2 or input5 ...
Key::CapsLock => "/sys/class/leds/input1::capslock/brightness",
Key::NumLock => "/sys/class/leds/input1::numlock/brightness",
_ => {
return false;
}
};
let status = if let Ok(mut file) = std::fs::File::open(&led_file) {
let mut content = String::new();
file.read_to_string(&mut content).ok();
let status = content.trim_end().to_string().parse::<i32>().unwrap_or(0);
status
} else {
0
};
status == 1
}
impl KeyboardControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn get_key_state(&mut self, key: Key) -> bool {
if self.is_x11 {
self.xdo.get_key_state(key)
} else {
if let Some(keyboard) = &mut self.custom_keyboard {
keyboard.get_key_state(key)
} else {
get_led_state(key)
}
}
}
/// Warning: Get 6^ in French.
fn key_sequence(&mut self, sequence: &str) {
if self.is_x11 {
self.xdo.key_sequence(sequence)
} else {
if let Some(keyboard) = &mut self.custom_keyboard {
keyboard.key_sequence(sequence)
} else {
log::warn!("Enigo::key_sequence: no custom_keyboard set for Wayland!");
}
}
}
fn key_down(&mut self, key: Key) -> crate::ResultType {
if self.is_x11 {
let has_down = self.tfc_key_down_or_up(key, true, false);
if !has_down {
self.xdo.key_down(key)
} else {
Ok(())
}
} else {
if let Some(keyboard) = &mut self.custom_keyboard {
keyboard.key_down(key)
} else {
log::warn!("Enigo::key_down: no custom_keyboard set for Wayland!");
Ok(())
}
}
}
fn key_up(&mut self, key: Key) {
if self.is_x11 {
let has_down = self.tfc_key_down_or_up(key, false, true);
if !has_down {
self.xdo.key_up(key)
}
} else {
if let Some(keyboard) = &mut self.custom_keyboard {
keyboard.key_up(key)
} else {
log::warn!("Enigo::key_up: no custom_keyboard set for Wayland!");
}
}
}
fn key_click(&mut self, key: Key) {
if self.is_x11 {
// X11: try tfc first, then fallback to key_down/key_up
if self.tfc_key_click(key).is_err() {
self.key_down(key).ok();
self.key_up(key);
}
} else {
if let Some(keyboard) = &mut self.custom_keyboard {
keyboard.key_click(key);
} else {
log::warn!("Enigo::key_click: no custom_keyboard set for Wayland!");
}
}
}
}
fn convert_to_tfc_key(key: Key) -> Option<TFC_Key> {
let key = match key {
Key::Alt => TFC_Key::Alt,
Key::Backspace => TFC_Key::DeleteOrBackspace,
Key::CapsLock => TFC_Key::CapsLock,
Key::Control => TFC_Key::Control,
Key::Delete => TFC_Key::ForwardDelete,
Key::DownArrow => TFC_Key::DownArrow,
Key::End => TFC_Key::End,
Key::Escape => TFC_Key::Escape,
Key::F1 => TFC_Key::F1,
Key::F10 => TFC_Key::F10,
Key::F11 => TFC_Key::F11,
Key::F12 => TFC_Key::F12,
Key::F2 => TFC_Key::F2,
Key::F3 => TFC_Key::F3,
Key::F4 => TFC_Key::F4,
Key::F5 => TFC_Key::F5,
Key::F6 => TFC_Key::F6,
Key::F7 => TFC_Key::F7,
Key::F8 => TFC_Key::F8,
Key::F9 => TFC_Key::F9,
Key::Home => TFC_Key::Home,
Key::LeftArrow => TFC_Key::LeftArrow,
Key::PageDown => TFC_Key::PageDown,
Key::PageUp => TFC_Key::PageUp,
Key::Return => TFC_Key::ReturnOrEnter,
Key::RightArrow => TFC_Key::RightArrow,
Key::Shift => TFC_Key::Shift,
Key::Space => TFC_Key::Space,
Key::Tab => TFC_Key::Tab,
Key::UpArrow => TFC_Key::UpArrow,
Key::Numpad0 => TFC_Key::N0,
Key::Numpad1 => TFC_Key::N1,
Key::Numpad2 => TFC_Key::N2,
Key::Numpad3 => TFC_Key::N3,
Key::Numpad4 => TFC_Key::N4,
Key::Numpad5 => TFC_Key::N5,
Key::Numpad6 => TFC_Key::N6,
Key::Numpad7 => TFC_Key::N7,
Key::Numpad8 => TFC_Key::N8,
Key::Numpad9 => TFC_Key::N9,
Key::Decimal => TFC_Key::NumpadDecimal,
Key::Clear => TFC_Key::NumpadClear,
Key::Pause => TFC_Key::Pause,
Key::Print => TFC_Key::Print,
Key::Snapshot => TFC_Key::PrintScreen,
Key::Insert => TFC_Key::Insert,
Key::Scroll => TFC_Key::ScrollLock,
Key::NumLock => TFC_Key::NumLock,
Key::RWin => TFC_Key::Meta,
Key::Apps => TFC_Key::Apps,
Key::Multiply => TFC_Key::NumpadMultiply,
Key::Add => TFC_Key::NumpadPlus,
Key::Subtract => TFC_Key::NumpadMinus,
Key::Divide => TFC_Key::NumpadDivide,
Key::Equals => TFC_Key::NumpadEquals,
Key::NumpadEnter => TFC_Key::NumpadEnter,
Key::RightShift => TFC_Key::RightShift,
Key::RightControl => TFC_Key::RightControl,
Key::RightAlt => TFC_Key::RightAlt,
Key::Command | Key::Super | Key::Windows | Key::Meta => TFC_Key::Meta,
_ => {
return None;
}
};
Some(key)
}
#[test]
fn test_key_seq() {
// Get 6^ in French.
let mut en = Enigo::new();
en.key_sequence("^^");
}
+459
View File
@@ -0,0 +1,459 @@
//! XDO-based input emulation for Linux.
//!
//! This module uses libxdo-sys (patched to use dynamic loading stub) for input emulation.
//! The stub handles dynamic loading of libxdo, so we just call the functions directly.
//!
//! If libxdo is not available at runtime, all operations become no-ops.
use crate::{Key, KeyboardControllable, MouseButton, MouseControllable};
use hbb_common::libc::c_int;
use hbb_common::x11::xlib::{Display, XCloseDisplay, XGetPointerMapping, XOpenDisplay};
use libxdo_sys::{self, xdo_t, CURRENTWINDOW};
use std::{borrow::Cow, ffi::CString};
/// Default delay per keypress in microseconds.
/// This value is passed to libxdo functions and must fit in `useconds_t` (u32).
const DEFAULT_DELAY: u64 = 12000;
/// Maximum allowed delay value (u32::MAX as u64).
const MAX_DELAY: u64 = u32::MAX as u64;
fn mousebutton(button: MouseButton) -> c_int {
match button {
MouseButton::Left => 1,
MouseButton::Middle => 2,
MouseButton::Right => 3,
MouseButton::ScrollUp => 4,
MouseButton::ScrollDown => 5,
MouseButton::ScrollLeft => 6,
MouseButton::ScrollRight => 7,
MouseButton::Back => 8,
MouseButton::Forward => 9,
}
}
/// Minimum number of buttons the X11 core pointer must support.
/// Buttons 8 (Back) and 9 (Forward) are needed for mouse side buttons.
const MIN_POINTER_BUTTONS: usize = 9;
/// Check that the X11 core pointer's button map includes at least 9 buttons
/// so that `XTestFakeButtonEvent` can simulate Back (8) and Forward (9).
///
/// RustDesk's uinput "Mouse passthrough" device normally provides enough
/// buttons, but we log a warning if the map is too small so the issue is
/// diagnosable. `XSetPointerMapping` cannot extend the button count (its
/// length must match `XGetPointerMapping`), so we only diagnose here.
fn check_x11_button_map() {
// Skip on non-X11 sessions to avoid noisy "XOpenDisplay failed" warnings
// on pure Wayland or headless environments without $DISPLAY.
if std::env::var_os("DISPLAY").is_none() {
return;
}
let display: *mut Display = unsafe { XOpenDisplay(std::ptr::null()) };
if display.is_null() {
log::warn!("XOpenDisplay failed, cannot check button map");
return;
}
let mut current_map = [0u8; 32];
let nbuttons =
unsafe { XGetPointerMapping(display, current_map.as_mut_ptr(), current_map.len() as i32) };
unsafe { XCloseDisplay(display) };
if nbuttons < 0 {
log::warn!("XGetPointerMapping failed (returned {nbuttons})");
return;
}
let nbuttons = nbuttons as usize;
if nbuttons >= MIN_POINTER_BUTTONS {
log::info!("X11 pointer has {nbuttons} buttons, side buttons supported");
} else {
log::warn!(
"X11 pointer has only {nbuttons} buttons (need {MIN_POINTER_BUTTONS}); \
back/forward side buttons may not work until a device with more buttons is added"
);
}
}
/// The main struct for handling the event emitting
pub(super) struct EnigoXdo {
xdo: *mut xdo_t,
delay: u64,
}
// This is safe, we have a unique pointer.
// TODO: use Unique<c_char> once stable.
unsafe impl Send for EnigoXdo {}
impl Default for EnigoXdo {
/// Create a new EnigoXdo instance.
///
/// If libxdo is not available, the xdo pointer will be null and all
/// input operations will be no-ops.
fn default() -> Self {
let xdo = unsafe { libxdo_sys::xdo_new(std::ptr::null()) };
if xdo.is_null() {
log::warn!("Failed to create xdo context, xdo functions will be disabled");
} else {
log::info!("xdo context created successfully");
check_x11_button_map();
}
Self {
xdo,
delay: DEFAULT_DELAY,
}
}
}
impl EnigoXdo {
/// Get the delay per keypress in microseconds.
///
/// Default value is 12000 (12ms). This is Linux-specific.
pub fn delay(&self) -> u64 {
self.delay
}
/// Set the delay per keypress in microseconds.
///
/// This is Linux-specific. The value is clamped to `u32::MAX` (approximately
/// 4295 seconds) because libxdo uses `useconds_t` which is typically `u32`.
///
/// # Arguments
/// * `delay` - Delay in microseconds. Values exceeding `u32::MAX` will be clamped.
pub fn set_delay(&mut self, delay: u64) {
self.delay = delay.min(MAX_DELAY);
if delay > MAX_DELAY {
log::warn!(
"delay value {} exceeds maximum {}, clamped",
delay,
MAX_DELAY
);
}
}
}
impl Drop for EnigoXdo {
fn drop(&mut self) {
if !self.xdo.is_null() {
unsafe {
libxdo_sys::xdo_free(self.xdo);
}
}
}
}
impl MouseControllable for EnigoXdo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn mouse_move_to(&mut self, x: i32, y: i32) {
if self.xdo.is_null() {
return;
}
unsafe {
libxdo_sys::xdo_move_mouse(self.xdo as *const _, x, y, 0);
}
}
fn mouse_move_relative(&mut self, x: i32, y: i32) {
if self.xdo.is_null() {
return;
}
unsafe {
libxdo_sys::xdo_move_mouse_relative(self.xdo as *const _, x, y);
}
}
fn mouse_down(&mut self, button: MouseButton) -> crate::ResultType {
if self.xdo.is_null() {
return Ok(());
}
unsafe {
libxdo_sys::xdo_mouse_down(self.xdo as *const _, CURRENTWINDOW, mousebutton(button));
}
Ok(())
}
fn mouse_up(&mut self, button: MouseButton) {
if self.xdo.is_null() {
return;
}
unsafe {
libxdo_sys::xdo_mouse_up(self.xdo as *const _, CURRENTWINDOW, mousebutton(button));
}
}
fn mouse_click(&mut self, button: MouseButton) {
if self.xdo.is_null() {
return;
}
unsafe {
libxdo_sys::xdo_click_window(self.xdo as *const _, CURRENTWINDOW, mousebutton(button));
}
}
fn mouse_scroll_x(&mut self, length: i32) {
let button;
let mut length = length;
if length < 0 {
button = MouseButton::ScrollLeft;
} else {
button = MouseButton::ScrollRight;
}
if length < 0 {
length = -length;
}
for _ in 0..length {
self.mouse_click(button);
}
}
fn mouse_scroll_y(&mut self, length: i32) {
let button;
let mut length = length;
if length < 0 {
button = MouseButton::ScrollUp;
} else {
button = MouseButton::ScrollDown;
}
if length < 0 {
length = -length;
}
for _ in 0..length {
self.mouse_click(button);
}
}
}
fn keysequence<'a>(key: Key) -> Cow<'a, str> {
if let Key::Layout(c) = key {
return Cow::Owned(format!("U{:X}", c as u32));
}
if let Key::Raw(k) = key {
return Cow::Owned(format!("{}", k as u16));
}
#[allow(deprecated)]
// I mean duh, we still need to support deprecated keys until they're removed
// https://www.rubydoc.info/gems/xdo/XDo/Keyboard
// https://gitlab.com/cunidev/gestures/-/wikis/xdotool-list-of-key-codes
Cow::Borrowed(match key {
Key::Alt => "Alt",
Key::Backspace => "BackSpace",
Key::CapsLock => "Caps_Lock",
Key::Control => "Control",
Key::Delete => "Delete",
Key::DownArrow => "Down",
Key::End => "End",
Key::Escape => "Escape",
Key::F1 => "F1",
Key::F10 => "F10",
Key::F11 => "F11",
Key::F12 => "F12",
Key::F2 => "F2",
Key::F3 => "F3",
Key::F4 => "F4",
Key::F5 => "F5",
Key::F6 => "F6",
Key::F7 => "F7",
Key::F8 => "F8",
Key::F9 => "F9",
Key::Home => "Home",
//Key::Layout(_) => unreachable!(),
Key::LeftArrow => "Left",
Key::Option => "Option",
Key::PageDown => "Page_Down",
Key::PageUp => "Page_Up",
//Key::Raw(_) => unreachable!(),
Key::Return => "Return",
Key::RightArrow => "Right",
Key::Shift => "Shift",
Key::Space => "space",
Key::Tab => "Tab",
Key::UpArrow => "Up",
Key::Numpad0 => "U30", //"KP_0",
Key::Numpad1 => "U31", //"KP_1",
Key::Numpad2 => "U32", //"KP_2",
Key::Numpad3 => "U33", //"KP_3",
Key::Numpad4 => "U34", //"KP_4",
Key::Numpad5 => "U35", //"KP_5",
Key::Numpad6 => "U36", //"KP_6",
Key::Numpad7 => "U37", //"KP_7",
Key::Numpad8 => "U38", //"KP_8",
Key::Numpad9 => "U39", //"KP_9",
Key::Decimal => "U2E", //"KP_Decimal",
Key::Cancel => "Cancel",
Key::Clear => "Clear",
Key::Pause => "Pause",
Key::Kana => "Kana",
Key::Hangul => "Hangul",
Key::Junja => "",
Key::Final => "",
Key::Hanja => "Hanja",
Key::Kanji => "Kanji",
Key::Convert => "",
Key::Select => "Select",
Key::Print => "Print",
Key::Execute => "Execute",
Key::Snapshot => "3270_PrintScreen",
Key::Insert => "Insert",
Key::Help => "Help",
Key::Sleep => "",
Key::Separator => "KP_Separator",
Key::VolumeUp => "",
Key::VolumeDown => "",
Key::Mute => "",
Key::Scroll => "Scroll_Lock",
Key::NumLock => "Num_Lock",
Key::RWin => "Super_R",
Key::Apps => "Menu",
Key::Multiply => "KP_Multiply",
Key::Add => "KP_Add",
Key::Subtract => "KP_Subtract",
Key::Divide => "KP_Divide",
Key::Equals => "KP_Equal",
Key::NumpadEnter => "KP_Enter",
Key::RightShift => "Shift_R",
Key::RightControl => "Control_R",
Key::RightAlt => "Alt_R",
Key::Command | Key::Super | Key::Windows | Key::Meta => "Super",
_ => "",
})
}
impl KeyboardControllable for EnigoXdo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn get_key_state(&mut self, key: Key) -> bool {
if self.xdo.is_null() {
return false;
}
/*
// modifier keys mask
pub const ShiftMask: c_uint = 0x01;
pub const LockMask: c_uint = 0x02;
pub const ControlMask: c_uint = 0x04;
pub const Mod1Mask: c_uint = 0x08;
pub const Mod2Mask: c_uint = 0x10;
pub const Mod3Mask: c_uint = 0x20;
pub const Mod4Mask: c_uint = 0x40;
pub const Mod5Mask: c_uint = 0x80;
*/
let mod_shift = 1 << 0;
let mod_lock = 1 << 1;
let mod_control = 1 << 2;
let mod_alt = 1 << 3;
let mod_numlock = 1 << 4;
let mod_meta = 1 << 6;
let mask = unsafe { libxdo_sys::xdo_get_input_state(self.xdo as *const _) };
match key {
Key::Shift => mask & mod_shift != 0,
Key::CapsLock => mask & mod_lock != 0,
Key::Control => mask & mod_control != 0,
Key::Alt => mask & mod_alt != 0,
Key::NumLock => mask & mod_numlock != 0,
Key::Meta => mask & mod_meta != 0,
_ => false,
}
}
fn key_sequence(&mut self, sequence: &str) {
if self.xdo.is_null() {
return;
}
if let Ok(string) = CString::new(sequence) {
unsafe {
libxdo_sys::xdo_enter_text_window(
self.xdo as *const _,
CURRENTWINDOW,
string.as_ptr(),
self.delay as libxdo_sys::useconds_t,
);
}
}
}
fn key_down(&mut self, key: Key) -> crate::ResultType {
if self.xdo.is_null() {
return Ok(());
}
let string = CString::new(&*keysequence(key))?;
unsafe {
libxdo_sys::xdo_send_keysequence_window_down(
self.xdo as *const _,
CURRENTWINDOW,
string.as_ptr(),
self.delay as libxdo_sys::useconds_t,
);
}
Ok(())
}
fn key_up(&mut self, key: Key) {
if self.xdo.is_null() {
return;
}
if let Ok(string) = CString::new(&*keysequence(key)) {
unsafe {
libxdo_sys::xdo_send_keysequence_window_up(
self.xdo as *const _,
CURRENTWINDOW,
string.as_ptr(),
self.delay as libxdo_sys::useconds_t,
);
}
}
}
fn key_click(&mut self, key: Key) {
if self.xdo.is_null() {
return;
}
if let Ok(string) = CString::new(&*keysequence(key)) {
unsafe {
libxdo_sys::xdo_send_keysequence_window(
self.xdo as *const _,
CURRENTWINDOW,
string.as_ptr(),
self.delay as libxdo_sys::useconds_t,
);
}
}
}
fn key_sequence_parse(&mut self, sequence: &str)
where
Self: Sized,
{
if let Err(..) = self.key_sequence_parse_try(sequence) {
println!("Could not parse sequence");
}
}
fn key_sequence_parse_try(&mut self, sequence: &str) -> Result<(), crate::dsl::ParseError>
where
Self: Sized,
{
crate::dsl::eval(self, sequence)
}
}
+120
View File
@@ -0,0 +1,120 @@
// https://stackoverflow.com/questions/3202629/where-can-i-find-a-list-of-mac-virtual-key-codes
/* keycodes for keys that are independent of keyboard layout */
#![allow(non_upper_case_globals)]
#![allow(dead_code)]
pub const kVK_Return: u16 = 0x24;
pub const kVK_Tab: u16 = 0x30;
pub const kVK_Space: u16 = 0x31;
pub const kVK_Delete: u16 = 0x33;
pub const kVK_Escape: u16 = 0x35;
pub const kVK_Command: u16 = 0x37;
pub const kVK_Shift: u16 = 0x38;
pub const kVK_CapsLock: u16 = 0x39;
pub const kVK_Option: u16 = 0x3A;
pub const kVK_Control: u16 = 0x3B;
pub const kVK_RightShift: u16 = 0x3C;
pub const kVK_RightOption: u16 = 0x3D;
pub const kVK_RightControl: u16 = 0x3E;
pub const kVK_Function: u16 = 0x3F;
pub const kVK_F17: u16 = 0x40;
pub const kVK_VolumeUp: u16 = 0x48;
pub const kVK_VolumeDown: u16 = 0x49;
pub const kVK_Mute: u16 = 0x4A;
pub const kVK_F18: u16 = 0x4F;
pub const kVK_F19: u16 = 0x50;
pub const kVK_F20: u16 = 0x5A;
pub const kVK_F5: u16 = 0x60;
pub const kVK_F6: u16 = 0x61;
pub const kVK_F7: u16 = 0x62;
pub const kVK_F3: u16 = 0x63;
pub const kVK_F8: u16 = 0x64;
pub const kVK_F9: u16 = 0x65;
pub const kVK_F11: u16 = 0x67;
pub const kVK_F13: u16 = 0x69;
pub const kVK_F16: u16 = 0x6A;
pub const kVK_F14: u16 = 0x6B;
pub const kVK_F10: u16 = 0x6D;
pub const kVK_F12: u16 = 0x6F;
pub const kVK_F15: u16 = 0x71;
pub const kVK_Help: u16 = 0x72;
pub const kVK_Home: u16 = 0x73;
pub const kVK_PageUp: u16 = 0x74;
pub const kVK_ForwardDelete: u16 = 0x75;
pub const kVK_F4: u16 = 0x76;
pub const kVK_End: u16 = 0x77;
pub const kVK_F2: u16 = 0x78;
pub const kVK_PageDown: u16 = 0x79;
pub const kVK_F1: u16 = 0x7A;
pub const kVK_LeftArrow: u16 = 0x7B;
pub const kVK_RightArrow: u16 = 0x7C;
pub const kVK_DownArrow: u16 = 0x7D;
pub const kVK_UpArrow: u16 = 0x7E;
pub const kVK_ANSI_Keypad0: u16 = 0x52;
pub const kVK_ANSI_Keypad1: u16 = 0x53;
pub const kVK_ANSI_Keypad2: u16 = 0x54;
pub const kVK_ANSI_Keypad3: u16 = 0x55;
pub const kVK_ANSI_Keypad4: u16 = 0x56;
pub const kVK_ANSI_Keypad5: u16 = 0x57;
pub const kVK_ANSI_Keypad6: u16 = 0x58;
pub const kVK_ANSI_Keypad7: u16 = 0x59;
pub const kVK_ANSI_Keypad8: u16 = 0x5B;
pub const kVK_ANSI_Keypad9: u16 = 0x5C;
pub const kVK_ANSI_KeypadClear: u16 = 0x47;
pub const kVK_ANSI_KeypadDecimal: u16 = 0x41;
pub const kVK_ANSI_KeypadMultiply: u16 = 0x43;
pub const kVK_ANSI_KeypadPlus: u16 = 0x45;
pub const kVK_ANSI_KeypadDivide: u16 = 0x4B;
pub const kVK_ANSI_KeypadEnter: u16 = 0x4C;
pub const kVK_ANSI_KeypadMinus: u16 = 0x4E;
pub const kVK_ANSI_KeypadEquals: u16 = 0x51;
pub const kVK_RIGHT_COMMAND: u16 = 0x36;
pub const kVK_ANSI_A : u16 = 0x00;
pub const kVK_ANSI_S : u16 = 0x01;
pub const kVK_ANSI_D : u16 = 0x02;
pub const kVK_ANSI_F : u16 = 0x03;
pub const kVK_ANSI_H : u16 = 0x04;
pub const kVK_ANSI_G : u16 = 0x05;
pub const kVK_ANSI_Z : u16 = 0x06;
pub const kVK_ANSI_X : u16 = 0x07;
pub const kVK_ANSI_C : u16 = 0x08;
pub const kVK_ANSI_V : u16 = 0x09;
pub const kVK_ANSI_B : u16 = 0x0B;
pub const kVK_ANSI_Q : u16 = 0x0C;
pub const kVK_ANSI_W : u16 = 0x0D;
pub const kVK_ANSI_E : u16 = 0x0E;
pub const kVK_ANSI_R : u16 = 0x0F;
pub const kVK_ANSI_Y : u16 = 0x10;
pub const kVK_ANSI_T : u16 = 0x11;
pub const kVK_ANSI_1 : u16 = 0x12;
pub const kVK_ANSI_2 : u16 = 0x13;
pub const kVK_ANSI_3 : u16 = 0x14;
pub const kVK_ANSI_4 : u16 = 0x15;
pub const kVK_ANSI_6 : u16 = 0x16;
pub const kVK_ANSI_5 : u16 = 0x17;
pub const kVK_ANSI_Equal : u16 = 0x18;
pub const kVK_ANSI_9 : u16 = 0x19;
pub const kVK_ANSI_7 : u16 = 0x1A;
pub const kVK_ANSI_Minus : u16 = 0x1B;
pub const kVK_ANSI_8 : u16 = 0x1C;
pub const kVK_ANSI_0 : u16 = 0x1D;
pub const kVK_ANSI_RightBracket : u16 = 0x1E;
pub const kVK_ANSI_O : u16 = 0x1F;
pub const kVK_ANSI_U : u16 = 0x20;
pub const kVK_ANSI_LeftBracket : u16 = 0x21;
pub const kVK_ANSI_I : u16 = 0x22;
pub const kVK_ANSI_P : u16 = 0x23;
pub const kVK_ANSI_L : u16 = 0x25;
pub const kVK_ANSI_J : u16 = 0x26;
pub const kVK_ANSI_Quote : u16 = 0x27;
pub const kVK_ANSI_K : u16 = 0x28;
pub const kVK_ANSI_Semicolon : u16 = 0x29;
pub const kVK_ANSI_Backslash : u16 = 0x2A;
pub const kVK_ANSI_Comma : u16 = 0x2B;
pub const kVK_ANSI_Slash : u16 = 0x2C;
pub const kVK_ANSI_N : u16 = 0x2D;
pub const kVK_ANSI_M : u16 = 0x2E;
pub const kVK_ANSI_Period : u16 = 0x2F;
pub const kVK_ANSI_Grave : u16 = 0x32;
+864
View File
@@ -0,0 +1,864 @@
use core_graphics;
// TODO(dustin): use only the things i need
use self::core_graphics::display::*;
use self::core_graphics::event::*;
use self::core_graphics::event_source::*;
use std::collections::HashMap as Map;
use std::ffi::c_void;
use std::ffi::CStr;
use std::os::raw::*;
use std::ptr::null_mut;
use crate::macos::keycodes::*;
use crate::{Key, KeyboardControllable, MouseButton, MouseControllable};
use objc::runtime::Class;
struct MyCGEvent;
type TISInputSourceRef = *mut c_void;
type CFDataRef = *const c_void;
type OptionBits = u32;
type OSStatus = i32;
type UniChar = u16;
type UniCharCount = usize;
type Boolean = c_uchar;
type CFStringEncoding = u32;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct __CFString([u8; 0]);
type CFStringRef = *const __CFString;
#[allow(non_upper_case_globals)]
const kCFStringEncodingUTF8: u32 = 134_217_984;
#[allow(non_upper_case_globals)]
const kUCKeyActionDisplay: u16 = 3;
#[allow(non_upper_case_globals)]
const kUCKeyTranslateDeadKeysBit: OptionBits = 1 << 31;
const BUF_LEN: usize = 4;
const MOUSE_EVENT_BUTTON_NUMBER_BACK: i64 = 3;
const MOUSE_EVENT_BUTTON_NUMBER_FORWARD: i64 = 4;
/// The event source user data value of cgevent.
pub const ENIGO_INPUT_EXTRA_VALUE: i64 = 100;
#[allow(improper_ctypes)]
#[allow(non_snake_case)]
#[link(name = "ApplicationServices", kind = "framework")]
#[link(name = "Carbon", kind = "framework")]
extern "C" {
fn CFDataGetBytePtr(theData: CFDataRef) -> *const u8;
fn TISCopyCurrentKeyboardInputSource() -> TISInputSourceRef;
fn TISCopyCurrentKeyboardLayoutInputSource() -> TISInputSourceRef;
fn TISCopyCurrentASCIICapableKeyboardLayoutInputSource() -> TISInputSourceRef;
static kTISPropertyUnicodeKeyLayoutData: *mut c_void;
static kTISPropertyInputSourceID: *mut c_void;
fn UCKeyTranslate(
keyLayoutPtr: *const u8, //*const UCKeyboardLayout,
virtualKeyCode: u16,
keyAction: u16,
modifierKeyState: u32,
keyboardType: u32,
keyTranslateOptions: OptionBits,
deadKeyState: *mut u32,
maxStringLength: UniCharCount,
actualStringLength: *mut UniCharCount,
unicodeString: *mut [UniChar; BUF_LEN],
) -> OSStatus;
fn LMGetKbdType() -> u8;
fn CFStringGetCString(
theString: CFStringRef,
buffer: *mut c_char,
bufferSize: CFIndex,
encoding: CFStringEncoding,
) -> Boolean;
fn CGEventPost(tapLocation: CGEventTapLocation, event: *mut MyCGEvent);
// Actually return CFDataRef which is const here, but for coding convenience, return *mut c_void
fn TISGetInputSourceProperty(source: TISInputSourceRef, property: *const c_void)
-> *mut c_void;
// not present in servo/core-graphics
fn CGEventCreateScrollWheelEvent(
source: &CGEventSourceRef,
units: ScrollUnit,
wheelCount: u32,
wheel1: i32,
...
) -> *mut MyCGEvent;
fn CGEventSourceKeyState(stateID: i32, key: u16) -> bool;
}
#[repr(C)]
#[derive(Clone, Copy)]
struct NSPoint {
x: f64,
y: f64,
}
// not present in servo/core-graphics
#[allow(dead_code)]
#[derive(Debug)]
enum ScrollUnit {
Pixel = 0,
Line = 1,
}
// hack
/// The main struct for handling the event emitting
pub struct Enigo {
event_source: Option<CGEventSource>,
double_click_interval: u32,
last_click_time: Option<std::time::Instant>,
multiple_click: i64,
ignore_flags: bool,
flags: CGEventFlags,
char_to_vkey_map: Map<String, Map<char, CGKeyCode>>,
}
impl Enigo {
/// Set if ignore flags when posting events.
pub fn set_ignore_flags(&mut self, ignore: bool) {
self.ignore_flags = ignore;
}
///
pub fn reset_flag(&mut self) {
self.flags = CGEventFlags::CGEventFlagNull;
}
///
pub fn add_flag(&mut self, key: &Key) {
let flag = match key {
&Key::CapsLock => CGEventFlags::CGEventFlagAlphaShift,
&Key::Shift => CGEventFlags::CGEventFlagShift,
&Key::Control => CGEventFlags::CGEventFlagControl,
&Key::Alt => CGEventFlags::CGEventFlagAlternate,
&Key::Meta => CGEventFlags::CGEventFlagCommand,
&Key::NumLock => CGEventFlags::CGEventFlagNumericPad,
_ => CGEventFlags::CGEventFlagNull,
};
self.flags |= flag;
}
// Just check F11 for minimal changes.
// Since enigo (legacy mode) is deprecated, it is currently in maintenance only.
fn post(&self, event: CGEvent, keycode: Option<u16>) {
if keycode == Some(kVK_F11) {
// Some key events require the flags to work.
// We can't simply set the flag to `CGEventFlags::CGEventFlagNull`.
// eg. `F11` requires flags `CGEventFlags::CGEventFlagSecondaryFn | 0x20000000` to work.
self.post_event(event, false);
} else {
// macOS system may use the previous event flag to generate the next event.
// Only found this issue when locking the screen.
// When we use enigo to lock the screen, the next mouse event will have the flag
// `CGEventFlagControl | CGEventFlagCommand | 0x20000000`.
// The key event will also have the flag `CGEventFlagControl | CGEventFlagCommand | 0x20000000`.
// Therefore, we need to set the flag to `event.set_flags(self.flags)` to avoid this.
self.post_event(event, true);
}
}
fn post_event(&self, event: CGEvent, force_flags: bool) {
if !self.ignore_flags && (force_flags || self.flags != CGEventFlags::CGEventFlagNull) {
event.set_flags(self.flags);
}
event.set_integer_value_field(EventField::EVENT_SOURCE_USER_DATA, ENIGO_INPUT_EXTRA_VALUE);
event.post(CGEventTapLocation::HID);
}
}
impl Default for Enigo {
fn default() -> Self {
let mut double_click_interval = 500;
if let Some(ns_event) = Class::get("NSEvent") {
let tm: f64 = unsafe { msg_send![ns_event, doubleClickInterval] };
if tm > 0. {
double_click_interval = (tm * 1000.) as u32;
log::info!("double click interval: {}ms", double_click_interval);
}
}
Self {
// TODO(dustin): return error rather than panic here
event_source: if let Ok(src) =
CGEventSource::new(CGEventSourceStateID::CombinedSessionState)
{
Some(src)
} else {
None
},
double_click_interval,
multiple_click: 1,
last_click_time: None,
ignore_flags: false,
flags: CGEventFlags::CGEventFlagNull,
char_to_vkey_map: Default::default(),
}
}
}
impl MouseControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn mouse_move_to(&mut self, x: i32, y: i32) {
// For absolute movement, we don't set delta values
// This maintains backward compatibility
self.mouse_move_to_impl(x, y, None);
}
fn mouse_move_relative(&mut self, x: i32, y: i32) {
let (display_width, display_height) = Self::main_display_size();
let (current_x, y_inv) = Self::mouse_location_raw_coords();
let current_y = (display_height as i32) - y_inv;
// Use saturating arithmetic to prevent overflow/wraparound
let mut new_x = current_x.saturating_add(x);
let mut new_y = current_y.saturating_add(y);
// Define screen center and edge margins for cursor reset
let center_x = (display_width / 2) as i32;
let center_y = (display_height / 2) as i32;
// Margin calculation: 5% of the smaller screen dimension with a minimum of 50px.
// This provides a comfortable buffer zone to detect when the cursor is approaching
// screen edges, allowing us to reset it to center before it hits the boundary.
// This ensures continuous relative mouse movement without getting stuck at edges.
let margin = (display_width.min(display_height) / 20).max(50) as i32;
// Check if cursor is approaching screen boundaries
// Use saturating_sub to prevent negative thresholds on very small displays
let right = (display_width as i32).saturating_sub(margin);
let bottom = (display_height as i32).saturating_sub(margin);
let near_edge = new_x < margin
|| new_x > right
|| new_y < margin
|| new_y > bottom;
if near_edge {
// Reset cursor to screen center to allow continuous movement
// The delta values are still passed correctly for games/apps
new_x = center_x;
new_y = center_y;
}
// Clamp to screen bounds as a safety measure.
// Use saturating_sub(1) to ensure coordinates don't exceed the last valid pixel.
let max_x = (display_width as i32).saturating_sub(1).max(0);
let max_y = (display_height as i32).saturating_sub(1).max(0);
new_x = new_x.clamp(0, max_x);
new_y = new_y.clamp(0, max_y);
// Pass delta values for relative movement
// This is critical for browser Pointer Lock API support
// The delta fields (MOUSE_EVENT_DELTA_X/Y) are used by browsers
// to calculate movementX/Y in Pointer Lock mode
self.mouse_move_to_impl(new_x, new_y, Some((x, y)));
}
fn mouse_down(&mut self, button: MouseButton) -> crate::ResultType {
let now = std::time::Instant::now();
if let Some(t) = self.last_click_time {
if t.elapsed().as_millis() as u32 <= self.double_click_interval {
self.multiple_click += 1;
} else {
self.multiple_click = 1;
}
}
self.last_click_time = Some(now);
let (current_x, current_y) = Self::mouse_location();
let (button, event_type, btn_value) = match button {
MouseButton::Left => (CGMouseButton::Left, CGEventType::LeftMouseDown, None),
MouseButton::Middle => (CGMouseButton::Center, CGEventType::OtherMouseDown, None),
MouseButton::Right => (CGMouseButton::Right, CGEventType::RightMouseDown, None),
MouseButton::Back => (
CGMouseButton::Left,
CGEventType::OtherMouseDown,
Some(MOUSE_EVENT_BUTTON_NUMBER_BACK),
),
MouseButton::Forward => (
CGMouseButton::Left,
CGEventType::OtherMouseDown,
Some(MOUSE_EVENT_BUTTON_NUMBER_FORWARD),
),
_ => {
log::info!("Unsupported button {:?}", button);
return Ok(());
}
};
let dest = CGPoint::new(current_x as f64, current_y as f64);
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_mouse_event(src.clone(), event_type, dest, button) {
if self.multiple_click > 1 {
event.set_integer_value_field(
EventField::MOUSE_EVENT_CLICK_STATE,
self.multiple_click,
);
}
if let Some(v) = btn_value {
event.set_integer_value_field(EventField::MOUSE_EVENT_BUTTON_NUMBER, v);
}
self.post(event, None);
}
}
Ok(())
}
fn mouse_up(&mut self, button: MouseButton) {
let (current_x, current_y) = Self::mouse_location();
let (button, event_type, btn_value) = match button {
MouseButton::Left => (CGMouseButton::Left, CGEventType::LeftMouseUp, None),
MouseButton::Middle => (CGMouseButton::Center, CGEventType::OtherMouseUp, None),
MouseButton::Right => (CGMouseButton::Right, CGEventType::RightMouseUp, None),
MouseButton::Back => (
CGMouseButton::Left,
CGEventType::OtherMouseUp,
Some(MOUSE_EVENT_BUTTON_NUMBER_BACK),
),
MouseButton::Forward => (
CGMouseButton::Left,
CGEventType::OtherMouseUp,
Some(MOUSE_EVENT_BUTTON_NUMBER_FORWARD),
),
_ => {
log::info!("Unsupported button {:?}", button);
return;
}
};
let dest = CGPoint::new(current_x as f64, current_y as f64);
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_mouse_event(src.clone(), event_type, dest, button) {
if self.multiple_click > 1 {
event.set_integer_value_field(
EventField::MOUSE_EVENT_CLICK_STATE,
self.multiple_click,
);
}
if let Some(v) = btn_value {
event.set_integer_value_field(EventField::MOUSE_EVENT_BUTTON_NUMBER, v);
}
self.post(event, None);
}
}
}
fn mouse_click(&mut self, button: MouseButton) {
self.mouse_down(button).ok();
self.mouse_up(button);
}
fn mouse_scroll_x(&mut self, length: i32) {
let mut scroll_direction = -1; // 1 left -1 right;
let mut length = length;
if length < 0 {
length *= -1;
scroll_direction *= -1;
}
if let Some(src) = self.event_source.as_ref() {
for _ in 0..length {
unsafe {
let mouse_ev = CGEventCreateScrollWheelEvent(
&src,
ScrollUnit::Line,
2, // CGWheelCount 1 = y 2 = xy 3 = xyz
0,
scroll_direction,
);
CGEventPost(CGEventTapLocation::HID, mouse_ev);
CFRelease(mouse_ev as *const std::ffi::c_void);
}
}
}
}
fn mouse_scroll_y(&mut self, length: i32) {
let mut scroll_direction = -1; // 1 left -1 right;
let mut length = length;
if length < 0 {
length *= -1;
scroll_direction *= -1;
}
if let Some(src) = self.event_source.as_ref() {
for _ in 0..length {
unsafe {
let mouse_ev = CGEventCreateScrollWheelEvent(
&src,
ScrollUnit::Line,
1, // CGWheelCount 1 = y 2 = xy 3 = xyz
scroll_direction,
);
CGEventPost(CGEventTapLocation::HID, mouse_ev);
CFRelease(mouse_ev as *const std::ffi::c_void);
}
}
}
}
}
// https://stackoverflow.
// com/questions/1918841/how-to-convert-ascii-character-to-cgkeycode
impl KeyboardControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn key_sequence(&mut self, sequence: &str) {
// NOTE(dustin): This is a fix for issue https://github.com/enigo-rs/enigo/issues/68
// TODO(dustin): This could be improved by aggregating 20 bytes worth of graphemes at a time
// but i am unsure what would happen for grapheme clusters greater than 20 bytes ...
use unicode_segmentation::UnicodeSegmentation;
let clusters = UnicodeSegmentation::graphemes(sequence, true).collect::<Vec<&str>>();
for cluster in clusters {
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_keyboard_event(src.clone(), 0, true) {
event.set_string(cluster);
self.post(event, None);
}
}
}
}
fn key_click(&mut self, key: Key) {
let keycode = self.key_to_keycode(key);
if keycode == u16::MAX {
return;
}
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_keyboard_event(src.clone(), keycode, true) {
self.post(event, Some(keycode));
}
if let Ok(event) = CGEvent::new_keyboard_event(src.clone(), keycode, false) {
self.post(event, Some(keycode));
}
}
}
fn key_down(&mut self, key: Key) -> crate::ResultType {
let code = self.key_to_keycode(key);
if code == u16::MAX {
return Err("".into());
}
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_keyboard_event(src.clone(), code, true) {
self.post(event, Some(code));
}
}
Ok(())
}
fn key_up(&mut self, key: Key) {
let code = self.key_to_keycode(key);
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) = CGEvent::new_keyboard_event(src.clone(), code, false) {
self.post(event, Some(code));
}
}
}
fn get_key_state(&mut self, key: Key) -> bool {
let keycode = self.key_to_keycode(key);
unsafe { CGEventSourceKeyState(1, keycode) }
}
}
impl Enigo {
fn pressed_buttons() -> usize {
if let Some(ns_event) = Class::get("NSEvent") {
unsafe { msg_send![ns_event, pressedMouseButtons] }
} else {
0
}
}
/// Internal implementation for mouse movement with optional delta values.
///
/// The `delta` parameter is crucial for browser Pointer Lock API support.
/// When a browser enters Pointer Lock mode, it reads mouse delta values
/// (MOUSE_EVENT_DELTA_X/Y) directly from CGEvent to calculate movementX/Y.
/// Without setting these fields, the browser sees zero movement.
fn mouse_move_to_impl(&mut self, x: i32, y: i32, delta: Option<(i32, i32)>) {
let pressed = Self::pressed_buttons();
// Determine event type and corresponding mouse button based on pressed buttons.
// The CGMouseButton must match the event type for drag events.
let (event_type, button) = if pressed & 1 > 0 {
(CGEventType::LeftMouseDragged, CGMouseButton::Left)
} else if pressed & 2 > 0 {
(CGEventType::RightMouseDragged, CGMouseButton::Right)
} else if pressed & 4 > 0 {
(CGEventType::OtherMouseDragged, CGMouseButton::Center)
} else {
(CGEventType::MouseMoved, CGMouseButton::Left) // Button doesn't matter for MouseMoved
};
let dest = CGPoint::new(x as f64, y as f64);
if let Some(src) = self.event_source.as_ref() {
if let Ok(event) =
CGEvent::new_mouse_event(src.clone(), event_type, dest, button)
{
// Set delta fields for relative mouse movement
// This is essential for Pointer Lock API in browsers
if let Some((dx, dy)) = delta {
event.set_integer_value_field(EventField::MOUSE_EVENT_DELTA_X, dx as i64);
event.set_integer_value_field(EventField::MOUSE_EVENT_DELTA_Y, dy as i64);
}
self.post(event, None);
}
}
}
/// Fetches the `(width, height)` in pixels of the main display
pub fn main_display_size() -> (usize, usize) {
let display_id = unsafe { CGMainDisplayID() };
let width = unsafe { CGDisplayPixelsWide(display_id) };
let height = unsafe { CGDisplayPixelsHigh(display_id) };
(width, height)
}
/// Returns the current mouse location in Cocoa coordinates which have Y
/// inverted from the Carbon coordinates used in the rest of the API.
/// This function exists so that mouse_move_relative only has to fetch
/// the screen size once.
fn mouse_location_raw_coords() -> (i32, i32) {
if let Some(ns_event) = Class::get("NSEvent") {
let pt: NSPoint = unsafe { msg_send![ns_event, mouseLocation] };
(pt.x as i32, pt.y as i32)
} else {
(0, 0)
}
}
/// The mouse coordinates in points, only works on the main display
pub fn mouse_location() -> (i32, i32) {
let (x, y_inv) = Self::mouse_location_raw_coords();
let (_, display_height) = Self::main_display_size();
(x, (display_height as i32) - y_inv)
}
fn key_to_keycode(&mut self, key: Key) -> CGKeyCode {
#[allow(deprecated)]
// I mean duh, we still need to support deprecated keys until they're removed
match key {
Key::Alt => kVK_Option,
Key::Backspace => kVK_Delete,
Key::CapsLock => kVK_CapsLock,
Key::Control => kVK_Control,
Key::Delete => kVK_ForwardDelete,
Key::DownArrow => kVK_DownArrow,
Key::End => kVK_End,
Key::Escape => kVK_Escape,
Key::F1 => kVK_F1,
Key::F10 => kVK_F10,
Key::F11 => kVK_F11,
Key::F12 => kVK_F12,
Key::F2 => kVK_F2,
Key::F3 => kVK_F3,
Key::F4 => kVK_F4,
Key::F5 => kVK_F5,
Key::F6 => kVK_F6,
Key::F7 => kVK_F7,
Key::F8 => kVK_F8,
Key::F9 => kVK_F9,
Key::Home => kVK_Home,
Key::LeftArrow => kVK_LeftArrow,
Key::Option => kVK_Option,
Key::PageDown => kVK_PageDown,
Key::PageUp => kVK_PageUp,
Key::Return => kVK_Return,
Key::RightArrow => kVK_RightArrow,
Key::Shift => kVK_Shift,
Key::Space => kVK_Space,
Key::Tab => kVK_Tab,
Key::UpArrow => kVK_UpArrow,
Key::Numpad0 => kVK_ANSI_Keypad0,
Key::Numpad1 => kVK_ANSI_Keypad1,
Key::Numpad2 => kVK_ANSI_Keypad2,
Key::Numpad3 => kVK_ANSI_Keypad3,
Key::Numpad4 => kVK_ANSI_Keypad4,
Key::Numpad5 => kVK_ANSI_Keypad5,
Key::Numpad6 => kVK_ANSI_Keypad6,
Key::Numpad7 => kVK_ANSI_Keypad7,
Key::Numpad8 => kVK_ANSI_Keypad8,
Key::Numpad9 => kVK_ANSI_Keypad9,
Key::Mute => kVK_Mute,
Key::VolumeDown => kVK_VolumeUp,
Key::VolumeUp => kVK_VolumeDown,
Key::Help => kVK_Help,
Key::Snapshot => kVK_F13,
Key::Clear => kVK_ANSI_KeypadClear,
Key::Decimal => kVK_ANSI_KeypadDecimal,
Key::Multiply => kVK_ANSI_KeypadMultiply,
Key::Add => kVK_ANSI_KeypadPlus,
Key::Divide => kVK_ANSI_KeypadDivide,
Key::NumpadEnter => kVK_ANSI_KeypadEnter,
Key::Subtract => kVK_ANSI_KeypadMinus,
Key::Equals => kVK_ANSI_KeypadEquals,
Key::NumLock => kVK_ANSI_KeypadClear,
Key::RWin => kVK_RIGHT_COMMAND,
Key::RightShift => kVK_RightShift,
Key::RightControl => kVK_RightControl,
Key::RightAlt => kVK_RightOption,
Key::Raw(raw_keycode) => raw_keycode,
Key::Layout(c) => self.map_key_board(c),
Key::Super | Key::Command | Key::Windows | Key::Meta => kVK_Command,
_ => u16::MAX,
}
}
#[inline]
fn map_key_board(&mut self, ch: char) -> CGKeyCode {
// no idea why below char not working with shift, https://github.com/rustdesk/rustdesk/issues/406#issuecomment-1145157327
// seems related to numpad char
if ch == '-' || ch == '=' || ch == '.' || ch == '/' || (ch >= '0' && ch <= '9') {
return self.map_key_board_en(ch);
}
let mut code = u16::MAX;
unsafe {
let (keyboard, layout) = get_layout();
if !keyboard.is_null() && !layout.is_null() {
let name_ref = TISGetInputSourceProperty(keyboard, kTISPropertyInputSourceID);
if !name_ref.is_null() {
let name = get_string(name_ref as _);
if let Some(name) = name {
if let Some(m) = self.char_to_vkey_map.get(&name) {
code = *m.get(&ch).unwrap_or(&u16::MAX);
} else {
let m = get_map(&name, layout);
code = *m.get(&ch).unwrap_or(&u16::MAX);
self.char_to_vkey_map.insert(name.clone(), m);
}
}
}
}
if !keyboard.is_null() {
CFRelease(keyboard);
}
}
if code != u16::MAX {
return code;
}
self.map_key_board_en(ch)
}
#[inline]
fn map_key_board_en(&mut self, ch: char) -> CGKeyCode {
match ch {
'a' => kVK_ANSI_A,
'b' => kVK_ANSI_B,
'c' => kVK_ANSI_C,
'd' => kVK_ANSI_D,
'e' => kVK_ANSI_E,
'f' => kVK_ANSI_F,
'g' => kVK_ANSI_G,
'h' => kVK_ANSI_H,
'i' => kVK_ANSI_I,
'j' => kVK_ANSI_J,
'k' => kVK_ANSI_K,
'l' => kVK_ANSI_L,
'm' => kVK_ANSI_M,
'n' => kVK_ANSI_N,
'o' => kVK_ANSI_O,
'p' => kVK_ANSI_P,
'q' => kVK_ANSI_Q,
'r' => kVK_ANSI_R,
's' => kVK_ANSI_S,
't' => kVK_ANSI_T,
'u' => kVK_ANSI_U,
'v' => kVK_ANSI_V,
'w' => kVK_ANSI_W,
'x' => kVK_ANSI_X,
'y' => kVK_ANSI_Y,
'z' => kVK_ANSI_Z,
'0' => kVK_ANSI_0,
'1' => kVK_ANSI_1,
'2' => kVK_ANSI_2,
'3' => kVK_ANSI_3,
'4' => kVK_ANSI_4,
'5' => kVK_ANSI_5,
'6' => kVK_ANSI_6,
'7' => kVK_ANSI_7,
'8' => kVK_ANSI_8,
'9' => kVK_ANSI_9,
'-' => kVK_ANSI_Minus,
'=' => kVK_ANSI_Equal,
'[' => kVK_ANSI_LeftBracket,
']' => kVK_ANSI_RightBracket,
'\\' => kVK_ANSI_Backslash,
';' => kVK_ANSI_Semicolon,
'\'' => kVK_ANSI_Quote,
',' => kVK_ANSI_Comma,
'.' => kVK_ANSI_Period,
'/' => kVK_ANSI_Slash,
'`' => kVK_ANSI_Grave,
_ => u16::MAX,
}
}
#[inline]
fn mouse_scroll_impl(&mut self, length: i32, is_track_pad: bool, is_horizontal: bool) {
let mut scroll_direction = -1; // 1 left -1 right;
let mut length = length;
if length < 0 {
length *= -1;
scroll_direction *= -1;
}
if let Some(src) = self.event_source.as_ref() {
for _ in 0..length {
unsafe {
let units = if is_track_pad {
ScrollUnit::Pixel
} else {
ScrollUnit::Line
};
let mouse_ev = if is_horizontal {
CGEventCreateScrollWheelEvent(
&src,
units,
2, // CGWheelCount 1 = y 2 = xy 3 = xyz
0,
scroll_direction,
)
} else {
CGEventCreateScrollWheelEvent(
&src,
units,
1, // CGWheelCount 1 = y 2 = xy 3 = xyz
scroll_direction,
)
};
CGEventPost(CGEventTapLocation::HID, mouse_ev);
CFRelease(mouse_ev as *const std::ffi::c_void);
}
}
}
}
/// handle scroll vertically
pub fn mouse_scroll_y(&mut self, length: i32, is_track_pad: bool) {
self.mouse_scroll_impl(length, is_track_pad, false)
}
/// handle scroll horizontally
pub fn mouse_scroll_x(&mut self, length: i32, is_track_pad: bool) {
self.mouse_scroll_impl(length, is_track_pad, true)
}
}
#[inline]
unsafe fn get_string(cf_string: CFStringRef) -> Option<String> {
if !cf_string.is_null() {
let mut buf: [i8; 255] = [0; 255];
let success = CFStringGetCString(
cf_string,
buf.as_mut_ptr(),
buf.len() as _,
kCFStringEncodingUTF8,
);
if success != 0 {
let name: &CStr = CStr::from_ptr(buf.as_ptr());
if let Ok(name) = name.to_str() {
return Some(name.to_string());
}
}
}
None
}
#[inline]
unsafe fn get_layout() -> (TISInputSourceRef, *const u8) {
let mut keyboard = TISCopyCurrentKeyboardInputSource();
let mut layout = null_mut();
if !keyboard.is_null() {
layout = TISGetInputSourceProperty(keyboard, kTISPropertyUnicodeKeyLayoutData);
}
if layout.is_null() {
if !keyboard.is_null() {
CFRelease(keyboard);
}
// https://github.com/microsoft/vscode/issues/23833
keyboard = TISCopyCurrentKeyboardLayoutInputSource();
if !keyboard.is_null() {
layout = TISGetInputSourceProperty(keyboard, kTISPropertyUnicodeKeyLayoutData);
}
}
if layout.is_null() {
if !keyboard.is_null() {
CFRelease(keyboard);
}
keyboard = TISCopyCurrentASCIICapableKeyboardLayoutInputSource();
if !keyboard.is_null() {
layout = TISGetInputSourceProperty(keyboard, kTISPropertyUnicodeKeyLayoutData);
}
}
if layout.is_null() {
if !keyboard.is_null() {
CFRelease(keyboard);
}
return (null_mut(), null_mut());
}
let layout_ptr = CFDataGetBytePtr(layout as _);
if layout_ptr.is_null() {
if !keyboard.is_null() {
CFRelease(keyboard);
}
return (null_mut(), null_mut());
}
(keyboard, layout_ptr)
}
#[inline]
fn get_map(name: &str, layout: *const u8) -> Map<char, CGKeyCode> {
log::info!("Create keyboard map for {}", name);
let mut keys_down: u32 = 0;
let mut map = Map::new();
for keycode in 0..128 {
let mut buff = [0_u16; BUF_LEN];
let kb_type = unsafe { LMGetKbdType() };
let mut length: UniCharCount = 0;
let _retval = unsafe {
UCKeyTranslate(
layout,
keycode,
kUCKeyActionDisplay as _,
0,
kb_type as _,
kUCKeyTranslateDeadKeysBit as _,
&mut keys_down,
BUF_LEN,
&mut length,
&mut buff,
)
};
if length > 0 {
if let Ok(str) = String::from_utf16(&buff[..length]) {
if let Some(chr) = str.chars().next() {
map.insert(chr, keycode as _);
}
}
}
}
map
}
unsafe impl Send for Enigo {}
+4
View File
@@ -0,0 +1,4 @@
mod macos_impl;
pub mod keycodes;
pub use self::macos_impl::{Enigo, ENIGO_INPUT_EXTRA_VALUE};
+83
View File
@@ -0,0 +1,83 @@
#![allow(dead_code)]
// https://msdn.microsoft.com/en-us/library/windows/desktop/dd375731
//
// JP/KR mapping https://github.com/TigerVNC/tigervnc/blob/1a008c1380305648ab50f1d99e73439747e9d61d/vncviewer/win32.c#L267
// altgr handle: https://github.com/TigerVNC/tigervnc/blob/dccb95f345f7a9c5aa785a19d1bfa3fdecd8f8e0/vncviewer/Viewport.cxx#L1066
pub const EVK_RETURN: u16 = 0x0D;
pub const EVK_TAB: u16 = 0x09;
pub const EVK_SPACE: u16 = 0x20;
pub const EVK_BACK: u16 = 0x08;
pub const EVK_ESCAPE: u16 = 0x1b;
pub const EVK_LWIN: u16 = 0x5b;
pub const EVK_SHIFT: u16 = 0x10;
//pub const EVK_LSHIFT: u16 = 0xa0;
pub const EVK_RSHIFT: u16 = 0xa1;
//pub const EVK_LMENU: u16 = 0xa4;
pub const EVK_RMENU: u16 = 0xa5;
pub const EVK_CAPITAL: u16 = 0x14;
pub const EVK_MENU: u16 = 0x12;
pub const EVK_LCONTROL: u16 = 0xa2;
pub const EVK_RCONTROL: u16 = 0xa3;
pub const EVK_HOME: u16 = 0x24;
pub const EVK_PRIOR: u16 = 0x21;
pub const EVK_NEXT: u16 = 0x22;
pub const EVK_END: u16 = 0x23;
pub const EVK_LEFT: u16 = 0x25;
pub const EVK_RIGHT: u16 = 0x27;
pub const EVK_UP: u16 = 0x26;
pub const EVK_DOWN: u16 = 0x28;
pub const EVK_DELETE: u16 = 0x2E;
pub const EVK_F1: u16 = 0x70;
pub const EVK_F2: u16 = 0x71;
pub const EVK_F3: u16 = 0x72;
pub const EVK_F4: u16 = 0x73;
pub const EVK_F5: u16 = 0x74;
pub const EVK_F6: u16 = 0x75;
pub const EVK_F7: u16 = 0x76;
pub const EVK_F8: u16 = 0x77;
pub const EVK_F9: u16 = 0x78;
pub const EVK_F10: u16 = 0x79;
pub const EVK_F11: u16 = 0x7a;
pub const EVK_F12: u16 = 0x7b;
pub const EVK_NUMPAD0: u16 = 0x60;
pub const EVK_NUMPAD1: u16 = 0x61;
pub const EVK_NUMPAD2: u16 = 0x62;
pub const EVK_NUMPAD3: u16 = 0x63;
pub const EVK_NUMPAD4: u16 = 0x64;
pub const EVK_NUMPAD5: u16 = 0x65;
pub const EVK_NUMPAD6: u16 = 0x66;
pub const EVK_NUMPAD7: u16 = 0x67;
pub const EVK_NUMPAD8: u16 = 0x68;
pub const EVK_NUMPAD9: u16 = 0x69;
pub const EVK_CANCEL: u16 = 0x03;
pub const EVK_CLEAR: u16 = 0x0C;
pub const EVK_PAUSE: u16 = 0x13;
pub const EVK_KANA: u16 = 0x15;
pub const EVK_HANGUL: u16 = 0x15;
pub const EVK_JUNJA: u16 = 0x17;
pub const EVK_FINAL: u16 = 0x18;
pub const EVK_HANJA: u16 = 0x19;
pub const EVK_KANJI: u16 = 0x19;
pub const EVK_CONVERT: u16 = 0x1C;
pub const EVK_SELECT: u16 = 0x29;
pub const EVK_PRINT: u16 = 0x2A;
pub const EVK_EXECUTE: u16 = 0x2B;
pub const EVK_SNAPSHOT: u16 = 0x2C;
pub const EVK_INSERT: u16 = 0x2D;
pub const EVK_HELP: u16 = 0x2F;
pub const EVK_SLEEP: u16 = 0x5F;
pub const EVK_SEPARATOR: u16 = 0x6C;
pub const EVK_VOLUME_MUTE: u16 = 0xAD;
pub const EVK_VOLUME_DOWN: u16 = 0xAE;
pub const EVK_VOLUME_UP: u16 = 0xAF;
pub const EVK_NUMLOCK: u16 = 0x90;
pub const EVK_SCROLL: u16 = 0x91;
pub const EVK_RWIN: u16 = 0x5C;
pub const EVK_APPS: u16 = 0x5D;
pub const EVK_ADD: u16 = 0x6B;
pub const EVK_MULTIPLY: u16 = 0x6A;
pub const EVK_SUBTRACT: u16 = 0x6D;
pub const EVK_DECIMAL: u16 = 0x6E;
pub const EVK_DIVIDE: u16 = 0x6F;
pub const EVK_PERIOD: u16 = 0xBE;
+4
View File
@@ -0,0 +1,4 @@
mod win_impl;
pub mod keycodes;
pub use self::win_impl::{Enigo, ENIGO_INPUT_EXTRA_VALUE};
+478
View File
@@ -0,0 +1,478 @@
use self::winapi::ctypes::c_int;
use self::winapi::shared::{basetsd::ULONG_PTR, minwindef::*, windef::*};
use self::winapi::um::winbase::*;
use self::winapi::um::winuser::*;
use winapi;
use crate::win::keycodes::*;
use crate::{Key, KeyboardControllable, MouseButton, MouseControllable};
use std::mem::*;
extern "system" {
pub fn GetLastError() -> DWORD;
}
/// The main struct for handling the event emitting
#[derive(Default)]
pub struct Enigo;
static mut LAYOUT: HKL = std::ptr::null_mut();
/// The dwExtraInfo value in keyboard and mouse structure that used in SendInput()
pub const ENIGO_INPUT_EXTRA_VALUE: ULONG_PTR = 100;
fn mouse_event(flags: u32, data: u32, dx: i32, dy: i32) -> DWORD {
let mut u = INPUT_u::default();
unsafe {
*u.mi_mut() = MOUSEINPUT {
dx,
dy,
mouseData: data,
dwFlags: flags,
time: 0,
dwExtraInfo: ENIGO_INPUT_EXTRA_VALUE,
};
}
let mut input = INPUT {
type_: INPUT_MOUSE,
u,
};
unsafe { SendInput(1, &mut input as LPINPUT, size_of::<INPUT>() as c_int) }
}
fn keybd_event(mut flags: u32, vk: u16, scan: u16) -> DWORD {
let mut scan = scan;
unsafe {
// https://github.com/rustdesk/rustdesk/issues/366
if scan == 0 {
if LAYOUT.is_null() {
let current_window_thread_id =
GetWindowThreadProcessId(GetForegroundWindow(), std::ptr::null_mut());
LAYOUT = GetKeyboardLayout(current_window_thread_id);
}
scan = MapVirtualKeyExW(vk as _, 0, LAYOUT) as _;
}
}
if flags & KEYEVENTF_UNICODE == 0 {
if scan >> 8 == 0xE0 || scan >> 8 == 0xE1 {
flags |= winapi::um::winuser::KEYEVENTF_EXTENDEDKEY;
}
}
let mut union: INPUT_u = unsafe { std::mem::zeroed() };
unsafe {
*union.ki_mut() = KEYBDINPUT {
wVk: vk,
wScan: scan,
dwFlags: flags,
time: 0,
dwExtraInfo: ENIGO_INPUT_EXTRA_VALUE,
};
}
let mut inputs = [INPUT {
type_: INPUT_KEYBOARD,
u: union,
}; 1];
unsafe {
SendInput(
inputs.len() as UINT,
inputs.as_mut_ptr(),
size_of::<INPUT>() as c_int,
)
}
}
fn get_error() -> String {
unsafe {
let buff_size = 256;
let mut buff: Vec<u16> = Vec::with_capacity(buff_size);
buff.resize(buff_size, 0);
let errno = GetLastError();
let chars_copied = FormatMessageW(
FORMAT_MESSAGE_IGNORE_INSERTS
| FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_ARGUMENT_ARRAY,
std::ptr::null(),
errno,
0,
buff.as_mut_ptr(),
(buff_size + 1) as u32,
std::ptr::null_mut(),
);
if chars_copied == 0 {
return "".to_owned();
}
let mut curr_char: usize = chars_copied as usize;
while curr_char > 0 {
let ch = buff[curr_char];
if ch >= ' ' as u16 {
break;
}
curr_char -= 1;
}
let sl = std::slice::from_raw_parts(buff.as_ptr(), curr_char);
let err_msg = String::from_utf16(sl);
return err_msg.unwrap_or("".to_owned());
}
}
impl MouseControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn mouse_move_to(&mut self, x: i32, y: i32) {
mouse_event(
MOUSEEVENTF_MOVE | MOUSEEVENTF_ABSOLUTE | MOUSEEVENTF_VIRTUALDESK,
0,
(x - unsafe { GetSystemMetrics(SM_XVIRTUALSCREEN) }) * 65535
/ unsafe { GetSystemMetrics(SM_CXVIRTUALSCREEN) },
(y - unsafe { GetSystemMetrics(SM_YVIRTUALSCREEN) }) * 65535
/ unsafe { GetSystemMetrics(SM_CYVIRTUALSCREEN) },
);
}
fn mouse_move_relative(&mut self, x: i32, y: i32) {
mouse_event(MOUSEEVENTF_MOVE, 0, x, y);
}
fn mouse_down(&mut self, button: MouseButton) -> crate::ResultType {
let res = mouse_event(
match button {
MouseButton::Left => MOUSEEVENTF_LEFTDOWN,
MouseButton::Middle => MOUSEEVENTF_MIDDLEDOWN,
MouseButton::Right => MOUSEEVENTF_RIGHTDOWN,
MouseButton::Back => MOUSEEVENTF_XDOWN,
MouseButton::Forward => MOUSEEVENTF_XDOWN,
_ => {
log::info!("Unsupported button {:?}", button);
return Ok(());
}
},
match button {
MouseButton::Back => XBUTTON1 as u32,
MouseButton::Forward => XBUTTON2 as u32,
_ => 0,
},
0,
0,
);
if res == 0 {
let err = get_error();
if !err.is_empty() {
return Err(err.into());
}
}
Ok(())
}
fn mouse_up(&mut self, button: MouseButton) {
mouse_event(
match button {
MouseButton::Left => MOUSEEVENTF_LEFTUP,
MouseButton::Middle => MOUSEEVENTF_MIDDLEUP,
MouseButton::Right => MOUSEEVENTF_RIGHTUP,
MouseButton::Back => MOUSEEVENTF_XUP,
MouseButton::Forward => MOUSEEVENTF_XUP,
_ => {
log::info!("Unsupported button {:?}", button);
return;
}
},
match button {
MouseButton::Back => XBUTTON1 as _,
MouseButton::Forward => XBUTTON2 as _,
_ => 0,
},
0,
0,
);
}
fn mouse_click(&mut self, button: MouseButton) {
self.mouse_down(button).ok();
self.mouse_up(button);
}
fn mouse_scroll_x(&mut self, length: i32) {
mouse_event(MOUSEEVENTF_HWHEEL, length as _, 0, 0);
}
fn mouse_scroll_y(&mut self, length: i32) {
mouse_event(MOUSEEVENTF_WHEEL, length as _, 0, 0);
}
}
impl KeyboardControllable for Enigo {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
self
}
fn key_sequence(&mut self, sequence: &str) {
let mut buffer = [0; 2];
for c in sequence.chars() {
// Windows uses uft-16 encoding. We need to check
// for variable length characters. As such some
// characters can be 32 bit long and those are
// encoded in such called hight and low surrogates
// each 16 bit wide that needs to be send after
// another to the SendInput function without
// being interrupted by "keyup"
let result = c.encode_utf16(&mut buffer);
if result.len() == 1 {
self.unicode_key_click(result[0]);
} else {
for utf16_surrogate in result {
self.unicode_key_down(utf16_surrogate.clone());
}
// do i need to produce a keyup?
// self.unicode_key_up(0);
}
}
}
fn key_click(&mut self, key: Key) {
let vk = self.key_to_keycode(key);
keybd_event(0, vk, 0);
keybd_event(KEYEVENTF_KEYUP, vk, 0);
}
fn key_down(&mut self, key: Key) -> crate::ResultType {
match &key {
Key::Layout(c) => {
// to-do: dup code
// https://github.com/rustdesk/rustdesk/blob/1bc0dd791ed8344997024dc46626bd2ca7df73d2/src/server/input_service.rs#L1348
let code = self.get_layoutdependent_keycode(*c);
if code as u16 != 0xFFFF {
let vk = code & 0x00FF;
let flag = code >> 8;
let modifiers = [Key::Shift, Key::Control, Key::Alt];
let mod_len = modifiers.len();
for pos in 0..mod_len {
if flag & (0x0001 << pos) != 0 {
self.key_down(modifiers[pos])?;
}
}
let res = keybd_event(0, vk, 0);
let err = if res == 0 { get_error() } else { "".to_owned() };
for pos in 0..mod_len {
let rpos = mod_len - 1 - pos;
if flag & (0x0001 << rpos) != 0 {
self.key_up(modifiers[rpos]);
}
}
if !err.is_empty() {
return Err(err.into());
}
} else {
return Err(format!("Failed to get keycode of {}", c).into());
}
}
_ => {
let code = self.key_to_keycode(key);
if code == 0 || code == 65535 {
return Err("".into());
}
let res = keybd_event(0, code, 0);
if res == 0 {
let err = get_error();
if !err.is_empty() {
return Err(err.into());
}
}
}
}
Ok(())
}
fn key_up(&mut self, key: Key) {
match key {
Key::Layout(c) => {
let code = self.get_layoutdependent_keycode(c);
if code as u16 != 0xFFFF {
let vk = code & 0x00FF;
keybd_event(KEYEVENTF_KEYUP, vk, 0);
}
}
_ => {
keybd_event(KEYEVENTF_KEYUP, self.key_to_keycode(key), 0);
}
}
}
fn get_key_state(&mut self, key: Key) -> bool {
let keycode = self.key_to_keycode(key);
let x = unsafe { GetKeyState(keycode as _) };
if key == Key::CapsLock || key == Key::NumLock || key == Key::Scroll {
return (x & 0x1) == 0x1;
}
return (x as u16 & 0x8000) == 0x8000;
}
}
impl Enigo {
/// Gets the (width, height) of the main display in screen coordinates
/// (pixels).
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut size = Enigo::main_display_size();
/// ```
pub fn main_display_size() -> (usize, usize) {
let w = unsafe { GetSystemMetrics(SM_CXSCREEN) as usize };
let h = unsafe { GetSystemMetrics(SM_CYSCREEN) as usize };
(w, h)
}
/// Gets the location of mouse in screen coordinates (pixels).
///
/// # Example
///
/// ```no_run
/// use enigo::*;
/// let mut location = Enigo::mouse_location();
/// ```
pub fn mouse_location() -> (i32, i32) {
let mut point = POINT { x: 0, y: 0 };
let result = unsafe { GetCursorPos(&mut point) };
if result != 0 {
(point.x, point.y)
} else {
(0, 0)
}
}
fn unicode_key_click(&self, unicode_char: u16) {
self.unicode_key_down(unicode_char);
self.unicode_key_up(unicode_char);
}
fn unicode_key_down(&self, unicode_char: u16) {
keybd_event(KEYEVENTF_UNICODE, 0, unicode_char);
}
fn unicode_key_up(&self, unicode_char: u16) {
keybd_event(KEYEVENTF_UNICODE | KEYEVENTF_KEYUP, 0, unicode_char);
}
fn key_to_keycode(&self, key: Key) -> u16 {
// do not use the codes from crate winapi they're
// wrongly typed with i32 instead of i16 use the
// ones provided by win/keycodes.rs that are prefixed
// with an 'E' infront of the original name
#[allow(deprecated)]
// I mean duh, we still need to support deprecated keys until they're removed
match key {
Key::Alt => EVK_MENU,
Key::Backspace => EVK_BACK,
Key::CapsLock => EVK_CAPITAL,
Key::Control => EVK_LCONTROL,
Key::Delete => EVK_DELETE,
Key::DownArrow => EVK_DOWN,
Key::End => EVK_END,
Key::Escape => EVK_ESCAPE,
Key::F1 => EVK_F1,
Key::F10 => EVK_F10,
Key::F11 => EVK_F11,
Key::F12 => EVK_F12,
Key::F2 => EVK_F2,
Key::F3 => EVK_F3,
Key::F4 => EVK_F4,
Key::F5 => EVK_F5,
Key::F6 => EVK_F6,
Key::F7 => EVK_F7,
Key::F8 => EVK_F8,
Key::F9 => EVK_F9,
Key::Home => EVK_HOME,
Key::LeftArrow => EVK_LEFT,
Key::Option => EVK_MENU,
Key::PageDown => EVK_NEXT,
Key::PageUp => EVK_PRIOR,
Key::Return => EVK_RETURN,
Key::RightArrow => EVK_RIGHT,
Key::Shift => EVK_SHIFT,
Key::Space => EVK_SPACE,
Key::Tab => EVK_TAB,
Key::UpArrow => EVK_UP,
Key::Numpad0 => EVK_NUMPAD0,
Key::Numpad1 => EVK_NUMPAD1,
Key::Numpad2 => EVK_NUMPAD2,
Key::Numpad3 => EVK_NUMPAD3,
Key::Numpad4 => EVK_NUMPAD4,
Key::Numpad5 => EVK_NUMPAD5,
Key::Numpad6 => EVK_NUMPAD6,
Key::Numpad7 => EVK_NUMPAD7,
Key::Numpad8 => EVK_NUMPAD8,
Key::Numpad9 => EVK_NUMPAD9,
Key::Cancel => EVK_CANCEL,
Key::Clear => EVK_CLEAR,
Key::Pause => EVK_PAUSE,
Key::Kana => EVK_KANA,
Key::Hangul => EVK_HANGUL,
Key::Junja => EVK_JUNJA,
Key::Final => EVK_FINAL,
Key::Hanja => EVK_HANJA,
Key::Kanji => EVK_KANJI,
Key::Convert => EVK_CONVERT,
Key::Select => EVK_SELECT,
Key::Print => EVK_PRINT,
Key::Execute => EVK_EXECUTE,
Key::Snapshot => EVK_SNAPSHOT,
Key::Insert => EVK_INSERT,
Key::Help => EVK_HELP,
Key::Sleep => EVK_SLEEP,
Key::Separator => EVK_SEPARATOR,
Key::Mute => EVK_VOLUME_MUTE,
Key::VolumeDown => EVK_VOLUME_DOWN,
Key::VolumeUp => EVK_VOLUME_UP,
Key::Scroll => EVK_SCROLL,
Key::NumLock => EVK_NUMLOCK,
Key::RWin => EVK_RWIN,
Key::Apps => EVK_APPS,
Key::Add => EVK_ADD,
Key::Multiply => EVK_MULTIPLY,
Key::Decimal => EVK_DECIMAL,
Key::Subtract => EVK_SUBTRACT,
Key::Divide => EVK_DIVIDE,
Key::NumpadEnter => EVK_RETURN,
Key::Equals => '=' as _,
Key::RightShift => EVK_RSHIFT,
Key::RightControl => EVK_RCONTROL,
Key::RightAlt => EVK_RMENU,
Key::Raw(raw_keycode) => raw_keycode,
Key::Super | Key::Command | Key::Windows | Key::Meta => EVK_LWIN,
Key::Layout(..) => {
// unreachable
0
}
}
}
fn get_layoutdependent_keycode(&self, chr: char) -> u16 {
unsafe {
LAYOUT = std::ptr::null_mut();
}
// NOTE VkKeyScanW uses the current keyboard LAYOUT
// to specify a LAYOUT use VkKeyScanExW and GetKeyboardLayout
// or load one with LoadKeyboardLayoutW
let current_window_thread_id =
unsafe { GetWindowThreadProcessId(GetForegroundWindow(), std::ptr::null_mut()) };
unsafe { LAYOUT = GetKeyboardLayout(current_window_thread_id) };
unsafe { VkKeyScanExW(chr as _, LAYOUT) as _ }
}
}
@@ -0,0 +1,3 @@
/target
**/*.rs.bk
Cargo.lock
+100
View File
@@ -0,0 +1,100 @@
[package]
name = "hbb_common"
version = "0.1.0"
authors = ["open-trade <info@opentradesolutions.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = []
webrtc = ["dep:webrtc"]
[dependencies]
# new flexi_logger failed on rustc 1.75
flexi_logger = { version = "0.27", features = ["async"] }
protobuf = { version = "3.7", features = ["with-bytes"] }
tokio = { version = "1.44", features = ["full"] }
tokio-util = { version = "0.7", features = ["full"] }
futures = "0.3"
bytes = { version = "1.10", features = ["serde"] }
log = "0.4"
env_logger = "0.11"
socket2 = { version = "0.3", features = ["reuseport"] }
zstd = "0.13"
anyhow = "1.0"
futures-util = "0.3"
directories-next = "2.0"
rand = "0.8"
serde_derive = "1.0"
serde = "1.0"
serde_json = "1.0"
lazy_static = "1.5"
confy = { git = "https://github.com/rustdesk-org/confy" }
dirs-next = "2.0"
filetime = "0.2"
sodiumoxide = "0.2"
regex = "1.11"
tokio-socks = { git = "https://github.com/rustdesk-org/tokio-socks" }
chrono = "0.4"
backtrace = "0.3"
libc = "0.2"
dlopen = "0.1"
toml = "0.7"
uuid = { version = "1.16", features = ["v4"] }
# new sysinfo issue: https://github.com/rustdesk/rustdesk/pull/6330#issuecomment-2270871442
sysinfo = { git = "https://github.com/rustdesk-org/sysinfo", branch = "rlim_max" }
# new flexi_logger failed on nightly rustc 1.75 for x86
thiserror = "1.0"
httparse = "1.10"
base64 = "0.22"
url = "2.5"
sha2 = "0.10"
whoami = "1.5"
tokio-rustls = { version = "0.26", features = [
"logging",
"tls12",
"ring",
], default-features = false }
tokio-native-tls = "0.3"
tokio-tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
rustls-platform-verifier = "0.6"
rustls-pki-types = "1.11"
rustls-native-certs = "0.8"
webpki-roots = "1.0.4"
async-recursion = "1.1"
webrtc = { version = "0.14.0", optional = true }
libloading = "0.8"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
default_net = { git = "https://github.com/rustdesk-org/default_net" }
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
[build-dependencies]
protobuf-codegen = { version = "3.7" }
[dev-dependencies]
clap = "4.5.51"
webrtc = "0.14.0"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = [
"winuser",
"synchapi",
"pdh",
"memoryapi",
"sysinfoapi",
] }
[target.'cfg(target_os = "macos")'.dependencies]
osascript = "0.3"
[target.'cfg(target_os = "linux")'.dependencies]
sctk = { package = "smithay-client-toolkit", version = "0.20.0", default-features = false, features = [
"calloop",
] }
users = { version = "0.11" }
x11 = "2.21"
+14
View File
@@ -0,0 +1,14 @@
fn main() {
let out_dir = format!("{}/protos", std::env::var("OUT_DIR").unwrap());
std::fs::create_dir_all(&out_dir).unwrap();
protobuf_codegen::Codegen::new()
.pure()
.out_dir(out_dir)
.inputs(["protos/rendezvous.proto", "protos/message.proto"])
.include("protos")
.customize(protobuf_codegen::Customize::default().tokio_bytes(true))
.run()
.expect("Codegen failed.");
}
+984
View File
@@ -0,0 +1,984 @@
syntax = "proto3";
package hbb;
message EncodedVideoFrame {
bytes data = 1;
bool key = 2;
int64 pts = 3;
}
message EncodedVideoFrames { repeated EncodedVideoFrame frames = 1; }
message RGB { bool compress = 1; }
// planes data send directly in binary for better use arraybuffer on web
message YUV {
bool compress = 1;
int32 stride = 2;
}
enum Chroma {
I420 = 0;
I444 = 1;
}
message VideoFrame {
oneof union {
EncodedVideoFrames vp9s = 6;
RGB rgb = 7;
YUV yuv = 8;
EncodedVideoFrames h264s = 10;
EncodedVideoFrames h265s = 11;
EncodedVideoFrames vp8s = 12;
EncodedVideoFrames av1s = 13;
}
int32 display = 14;
}
message IdPk {
string id = 1;
bytes pk = 2;
}
message DisplayInfo {
sint32 x = 1;
sint32 y = 2;
int32 width = 3;
int32 height = 4;
string name = 5;
bool online = 6;
bool cursor_embedded = 7;
Resolution original_resolution = 8;
double scale = 9;
}
message PortForward {
string host = 1;
int32 port = 2;
}
message FileTransfer {
string dir = 1;
bool show_hidden = 2;
}
message ViewCamera {}
message OSLogin {
string username = 1;
string password = 2;
}
message LoginRequest {
string username = 1;
bytes password = 2;
string my_id = 4;
string my_name = 5;
OptionMessage option = 6;
oneof union {
FileTransfer file_transfer = 7;
PortForward port_forward = 8;
ViewCamera view_camera = 15;
Terminal terminal = 16;
}
bool video_ack_required = 9;
uint64 session_id = 10;
string version = 11;
OSLogin os_login = 12;
string my_platform = 13;
bytes hwid = 14;
string avatar = 17;
}
message Terminal {
string service_id = 1; // Service ID for reconnecting to existing session
}
message Auth2FA {
string code = 1;
bytes hwid = 2;
}
message ChatMessage { string text = 1; }
message Features {
bool privacy_mode = 1;
bool terminal = 2;
}
message CodecAbility {
bool vp8 = 1;
bool vp9 = 2;
bool av1 = 3;
bool h264 = 4;
bool h265 = 5;
}
message SupportedEncoding {
bool h264 = 1;
bool h265 = 2;
bool vp8 = 3;
bool av1 = 4;
CodecAbility i444 = 5;
}
message PeerInfo {
string username = 1;
string hostname = 2;
string platform = 3;
repeated DisplayInfo displays = 4;
int32 current_display = 5;
bool sas_enabled = 6;
string version = 7;
Features features = 9;
SupportedEncoding encoding = 10;
SupportedResolutions resolutions = 11;
// Use JSON's key-value format which is friendly for peer to handle.
// NOTE: Only support one-level dictionaries (for peer to update), and the key is of type string.
string platform_additions = 12;
WindowsSessions windows_sessions = 13;
}
message WindowsSession {
uint32 sid = 1;
string name = 2;
}
message LoginResponse {
oneof union {
string error = 1;
PeerInfo peer_info = 2;
}
bool enable_trusted_devices = 3;
}
message TouchScaleUpdate {
// The delta scale factor relative to the previous scale.
// delta * 1000
// 0 means scale end
int32 scale = 1;
}
message TouchPanStart {
int32 x = 1;
int32 y = 2;
}
message TouchPanUpdate {
// The delta x position relative to the previous position.
int32 x = 1;
// The delta y position relative to the previous position.
int32 y = 2;
}
message TouchPanEnd {
int32 x = 1;
int32 y = 2;
}
message TouchEvent {
oneof union {
TouchScaleUpdate scale_update = 1;
TouchPanStart pan_start = 2;
TouchPanUpdate pan_update = 3;
TouchPanEnd pan_end = 4;
}
}
message PointerDeviceEvent {
oneof union {
TouchEvent touch_event = 1;
}
repeated ControlKey modifiers = 2;
}
message MouseEvent {
int32 mask = 1;
sint32 x = 2;
sint32 y = 3;
repeated ControlKey modifiers = 4;
}
enum KeyboardMode{
Legacy = 0;
Map = 1;
Translate = 2;
Auto = 3;
}
enum ControlKey {
Unknown = 0;
Alt = 1;
Backspace = 2;
CapsLock = 3;
Control = 4;
Delete = 5;
DownArrow = 6;
End = 7;
Escape = 8;
F1 = 9;
F10 = 10;
F11 = 11;
F12 = 12;
F2 = 13;
F3 = 14;
F4 = 15;
F5 = 16;
F6 = 17;
F7 = 18;
F8 = 19;
F9 = 20;
Home = 21;
LeftArrow = 22;
/// meta key (also known as "windows"; "super"; and "command")
Meta = 23;
/// option key on macOS (alt key on Linux and Windows)
Option = 24; // deprecated, use Alt instead
PageDown = 25;
PageUp = 26;
Return = 27;
RightArrow = 28;
Shift = 29;
Space = 30;
Tab = 31;
UpArrow = 32;
Numpad0 = 33;
Numpad1 = 34;
Numpad2 = 35;
Numpad3 = 36;
Numpad4 = 37;
Numpad5 = 38;
Numpad6 = 39;
Numpad7 = 40;
Numpad8 = 41;
Numpad9 = 42;
Cancel = 43;
Clear = 44;
Menu = 45; // deprecated, use Alt instead
Pause = 46;
Kana = 47;
Hangul = 48;
Junja = 49;
Final = 50;
Hanja = 51;
Kanji = 52;
Convert = 53;
Select = 54;
Print = 55;
Execute = 56;
Snapshot = 57;
Insert = 58;
Help = 59;
Sleep = 60;
Separator = 61;
Scroll = 62;
NumLock = 63;
RWin = 64;
Apps = 65;
Multiply = 66;
Add = 67;
Subtract = 68;
Decimal = 69;
Divide = 70;
Equals = 71;
NumpadEnter = 72;
RShift = 73;
RControl = 74;
RAlt = 75;
VolumeMute = 76; // mainly used on mobile devices as controlled side
VolumeUp = 77;
VolumeDown = 78;
Power = 79; // mainly used on mobile devices as controlled side
CtrlAltDel = 100;
LockScreen = 101;
}
message KeyEvent {
// `down` indicates the key's state(down or up).
bool down = 1;
// `press` indicates a click event(down and up).
bool press = 2;
oneof union {
ControlKey control_key = 3;
// position key code. win: scancode, linux: key code, macos: key code
uint32 chr = 4;
uint32 unicode = 5;
string seq = 6;
// high word. virtual keycode
// low word. unicode
uint32 win2win_hotkey = 7;
}
repeated ControlKey modifiers = 8;
KeyboardMode mode = 9;
}
message CursorData {
uint64 id = 1;
sint32 hotx = 2;
sint32 hoty = 3;
int32 width = 4;
int32 height = 5;
bytes colors = 6;
}
message CursorPosition {
sint32 x = 1;
sint32 y = 2;
}
message Hash {
string salt = 1;
string challenge = 2;
}
enum ClipboardFormat {
Text = 0;
Rtf = 1;
Html = 2;
ImageRgba = 21;
ImagePng = 22;
ImageSvg = 23;
Special = 31;
}
message Clipboard {
bool compress = 1;
bytes content = 2;
int32 width = 3;
int32 height = 4;
ClipboardFormat format = 5;
// Special format name, only used when format is Special.
string special_name = 6;
}
message MultiClipboards { repeated Clipboard clipboards = 1; }
enum FileType {
Dir = 0;
DirLink = 2;
DirDrive = 3;
File = 4;
FileLink = 5;
}
message FileEntry {
FileType entry_type = 1;
string name = 2;
bool is_hidden = 3;
uint64 size = 4;
uint64 modified_time = 5;
}
message FileDirectory {
int32 id = 1;
string path = 2;
repeated FileEntry entries = 3;
}
message ReadDir {
string path = 1;
bool include_hidden = 2;
}
message ReadEmptyDirs {
string path = 1;
bool include_hidden = 2;
}
message ReadEmptyDirsResponse {
string path = 1;
repeated FileDirectory empty_dirs = 2;
}
message ReadAllFiles {
int32 id = 1;
string path = 2;
bool include_hidden = 3;
}
message FileRename {
int32 id = 1;
string path = 2;
string new_name = 3;
}
message FileAction {
oneof union {
ReadDir read_dir = 1;
FileTransferSendRequest send = 2;
FileTransferReceiveRequest receive = 3;
FileDirCreate create = 4;
FileRemoveDir remove_dir = 5;
FileRemoveFile remove_file = 6;
ReadAllFiles all_files = 7;
FileTransferCancel cancel = 8;
FileTransferSendConfirmRequest send_confirm = 9;
FileRename rename = 10;
ReadEmptyDirs read_empty_dirs = 11;
}
}
message FileTransferCancel { int32 id = 1; }
message FileResponse {
oneof union {
FileDirectory dir = 1;
FileTransferBlock block = 2;
FileTransferError error = 3;
FileTransferDone done = 4;
FileTransferDigest digest = 5;
ReadEmptyDirsResponse empty_dirs = 6;
}
}
message FileTransferDigest {
int32 id = 1;
sint32 file_num = 2;
uint64 last_modified = 3;
uint64 file_size = 4;
bool is_upload = 5;
bool is_identical = 6;
uint64 transferred_size = 7; // For resume. Indicates the size of the file already transferred
bool is_resume = 8; // For resume. Indicates if the transfer is a resume.
// `is_resume` can let the controlled side know whether to check the `.digest` file.
// When `is_resume` is false, `.digest` exists, the same file does not exist,
// the controlled side should not check `.digest`, it should confirm with a new transfer request.
}
message FileTransferBlock {
int32 id = 1;
sint32 file_num = 2;
bytes data = 3;
bool compressed = 4;
uint32 blk_id = 5;
}
message FileTransferError {
int32 id = 1;
string error = 2;
sint32 file_num = 3;
}
message FileTransferSendRequest {
int32 id = 1;
string path = 2;
bool include_hidden = 3;
int32 file_num = 4;
enum FileType {
Generic = 0;
Printer = 1;
}
FileType file_type = 5;
}
message FileTransferSendConfirmRequest {
int32 id = 1;
sint32 file_num = 2;
oneof union {
bool skip = 3;
uint32 offset_blk = 4;
}
}
message FileTransferDone {
int32 id = 1;
sint32 file_num = 2;
}
message FileTransferReceiveRequest {
int32 id = 1;
string path = 2; // path written to
repeated FileEntry files = 3;
int32 file_num = 4;
uint64 total_size = 5;
}
message FileRemoveDir {
int32 id = 1;
string path = 2;
bool recursive = 3;
}
message FileRemoveFile {
int32 id = 1;
string path = 2;
sint32 file_num = 3;
}
message FileDirCreate {
int32 id = 1;
string path = 2;
}
// main logic from freeRDP
message CliprdrMonitorReady {
}
message CliprdrFormat {
int32 id = 2;
string format = 3;
}
message CliprdrServerFormatList {
repeated CliprdrFormat formats = 2;
}
message CliprdrServerFormatListResponse {
int32 msg_flags = 2;
}
message CliprdrServerFormatDataRequest {
int32 requested_format_id = 2;
}
message CliprdrServerFormatDataResponse {
int32 msg_flags = 2;
bytes format_data = 3;
}
message CliprdrFileContentsRequest {
int32 stream_id = 2;
int32 list_index = 3;
int32 dw_flags = 4;
int32 n_position_low = 5;
int32 n_position_high = 6;
int32 cb_requested = 7;
bool have_clip_data_id = 8;
int32 clip_data_id = 9;
}
message CliprdrFileContentsResponse {
int32 msg_flags = 3;
int32 stream_id = 4;
bytes requested_data = 5;
}
// Try empty clipboard in the following case(Windows only):
// 1. `A`(Windows) -> `B`, `C`
// 2. Copy in `A, file clipboards on `B` and `C` are updated.
// 3. Copy in `B`.
// `A` should tell `C` to empty the file clipboard.
message CliprdrTryEmpty {
}
// Clipobard file message for audit.
message CliprdrFile {
string name = 1;
uint64 size = 2;
}
message CliprdrFiles {
repeated CliprdrFile files = 1;
}
message Cliprdr {
oneof union {
CliprdrMonitorReady ready = 1;
CliprdrServerFormatList format_list = 2;
CliprdrServerFormatListResponse format_list_response = 3;
CliprdrServerFormatDataRequest format_data_request = 4;
CliprdrServerFormatDataResponse format_data_response = 5;
CliprdrFileContentsRequest file_contents_request = 6;
CliprdrFileContentsResponse file_contents_response = 7;
CliprdrTryEmpty try_empty = 8;
CliprdrFiles files = 9;
}
}
message Resolution {
int32 width = 1;
int32 height = 2;
}
message DisplayResolution {
int32 display = 1;
Resolution resolution = 2;
}
message SupportedResolutions { repeated Resolution resolutions = 1; }
message SwitchDisplay {
int32 display = 1;
sint32 x = 2;
sint32 y = 3;
int32 width = 4;
int32 height = 5;
bool cursor_embedded = 6;
SupportedResolutions resolutions = 7;
// Do not care about the origin point for now.
Resolution original_resolution = 8;
}
message CaptureDisplays {
repeated int32 add = 1;
repeated int32 sub = 2;
repeated int32 set = 3;
}
message ToggleVirtualDisplay {
int32 display = 1;
bool on = 2;
}
message TogglePrivacyMode {
string impl_key = 1;
bool on = 2;
}
message PermissionInfo {
enum Permission {
Keyboard = 0;
Clipboard = 2;
Audio = 3;
File = 4;
Restart = 5;
Recording = 6;
BlockInput = 7;
PrivacyMode = 8;
}
Permission permission = 1;
bool enabled = 2;
}
enum ImageQuality {
NotSet = 0;
Low = 2;
Balanced = 3;
Best = 4;
}
message SupportedDecoding {
enum PreferCodec {
Auto = 0;
VP9 = 1;
H264 = 2;
H265 = 3;
VP8 = 4;
AV1 = 5;
}
int32 ability_vp9 = 1;
int32 ability_h264 = 2;
int32 ability_h265 = 3;
PreferCodec prefer = 4;
int32 ability_vp8 = 5;
int32 ability_av1 = 6;
CodecAbility i444 = 7;
Chroma prefer_chroma = 8;
}
message OptionMessage {
enum BoolOption {
NotSet = 0;
No = 1;
Yes = 2;
}
ImageQuality image_quality = 1;
BoolOption lock_after_session_end = 2;
BoolOption show_remote_cursor = 3;
BoolOption privacy_mode = 4;
BoolOption block_input = 5;
int32 custom_image_quality = 6;
BoolOption disable_audio = 7;
BoolOption disable_clipboard = 8;
BoolOption enable_file_transfer = 9;
SupportedDecoding supported_decoding = 10;
int32 custom_fps = 11;
BoolOption disable_keyboard = 12;
// Position 13 is used for Resolution. Remove later.
// Resolution custom_resolution = 13;
// BoolOption support_windows_specific_session = 14;
// starting from 15 please, do not use removed fields
BoolOption follow_remote_cursor = 15;
BoolOption follow_remote_window = 16;
BoolOption disable_camera = 17;
BoolOption terminal_persistent = 18;
BoolOption show_my_cursor = 19;
}
message TestDelay {
int64 time = 1;
bool from_client = 2;
uint32 last_delay = 3;
uint32 target_bitrate = 4;
}
message PublicKey {
bytes asymmetric_value = 1;
bytes symmetric_value = 2;
}
message SignedId { bytes id = 1; }
message AudioFormat {
uint32 sample_rate = 1;
uint32 channels = 2;
}
message AudioFrame {
bytes data = 1;
}
// Notify peer to show message box.
message MessageBox {
// Message type. Refer to flutter/lib/common.dart/msgBox().
string msgtype = 1;
string title = 2;
// English
string text = 3;
// If not empty, msgbox provides a button to following the link.
// The link here can't be directly http url.
// It must be the key of http url configed in peer side or "rustdesk://*" (jump in app).
string link = 4;
}
message BackNotification {
// no need to consider block input by someone else
enum BlockInputState {
BlkStateUnknown = 0;
BlkOnSucceeded = 2;
BlkOnFailed = 3;
BlkOffSucceeded = 4;
BlkOffFailed = 5;
}
enum PrivacyModeState {
PrvStateUnknown = 0;
// Privacy mode on by someone else
PrvOnByOther = 2;
// Privacy mode is not supported on the remote side
PrvNotSupported = 3;
// Privacy mode on by self
PrvOnSucceeded = 4;
// Privacy mode on by self, but denied
PrvOnFailedDenied = 5;
// Some plugins are not found
PrvOnFailedPlugin = 6;
// Privacy mode on by self, but failed
PrvOnFailed = 7;
// Privacy mode off by self
PrvOffSucceeded = 8;
// Ctrl + P
PrvOffByPeer = 9;
// Privacy mode off by self, but failed
PrvOffFailed = 10;
PrvOffUnknown = 11;
}
oneof union {
PrivacyModeState privacy_mode_state = 1;
BlockInputState block_input_state = 2;
}
// Supplementary message, for "PrvOnFailed" and "PrvOffFailed"
string details = 3;
// The key of the implementation
string impl_key = 4;
}
message ElevationRequestWithLogon {
string username = 1;
string password = 2;
}
message ElevationRequest {
oneof union {
bool direct = 1;
ElevationRequestWithLogon logon = 2;
}
}
message SwitchSidesRequest {
bytes uuid = 1;
}
message SwitchSidesResponse {
bytes uuid = 1;
LoginRequest lr = 2;
}
message SwitchBack {}
message PluginRequest {
string id = 1;
bytes content = 2;
}
message PluginFailure {
string id = 1;
string name = 2;
string msg = 3;
}
message WindowsSessions {
repeated WindowsSession sessions = 1;
uint32 current_sid = 2;
}
// Query messages from peer.
message MessageQuery {
// The SwitchDisplay message of the target display.
// If the target display is not found, the message will be ignored.
int32 switch_display = 1;
}
message Misc {
oneof union {
ChatMessage chat_message = 4;
SwitchDisplay switch_display = 5;
PermissionInfo permission_info = 6;
OptionMessage option = 7;
AudioFormat audio_format = 8;
string close_reason = 9;
bool refresh_video = 10;
bool video_received = 12;
BackNotification back_notification = 13;
bool restart_remote_device = 14;
bool uac = 15;
bool foreground_window_elevated = 16;
bool stop_service = 17;
ElevationRequest elevation_request = 18;
string elevation_response = 19;
bool portable_service_running = 20;
SwitchSidesRequest switch_sides_request = 21;
SwitchBack switch_back = 22;
// Deprecated since 1.2.4, use `change_display_resolution` (36) instead.
// But we must keep it for compatibility when peer version < 1.2.4.
Resolution change_resolution = 24;
PluginRequest plugin_request = 25;
PluginFailure plugin_failure = 26;
uint32 full_speed_fps = 27; // deprecated
uint32 auto_adjust_fps = 28;
bool client_record_status = 29;
CaptureDisplays capture_displays = 30;
int32 refresh_video_display = 31;
ToggleVirtualDisplay toggle_virtual_display = 32;
TogglePrivacyMode toggle_privacy_mode = 33;
SupportedEncoding supported_encoding = 34;
uint32 selected_sid = 35;
DisplayResolution change_display_resolution = 36;
MessageQuery message_query = 37;
int32 follow_current_display = 38;
}
}
message VoiceCallRequest {
int64 req_timestamp = 1;
// Indicates whether the request is a connect action or a disconnect action.
bool is_connect = 2;
}
message VoiceCallResponse {
bool accepted = 1;
int64 req_timestamp = 2; // Should copy from [VoiceCallRequest::req_timestamp].
int64 ack_timestamp = 3;
}
message ScreenshotRequest {
int32 display = 1;
// sid is the session id on the controlling side
// It is used to forward the message to the correct remote (session) window.
string sid = 2;
}
message ScreenshotResponse {
string sid = 1;
// empty if success
string msg = 2;
bytes data = 3;
}
// Terminal messages - standalone feature like FileAction
message OpenTerminal {
int32 terminal_id = 1; // 0 for default terminal
uint32 rows = 2;
uint32 cols = 3;
}
message ResizeTerminal {
int32 terminal_id = 1;
uint32 rows = 2;
uint32 cols = 3;
}
message TerminalData {
int32 terminal_id = 1;
bytes data = 2;
bool compressed = 3;
}
message CloseTerminal {
int32 terminal_id = 1;
}
message TerminalAction {
oneof union {
OpenTerminal open = 1;
TerminalData data = 2;
ResizeTerminal resize = 3;
CloseTerminal close = 4;
}
}
message TerminalOpened {
int32 terminal_id = 1;
bool success = 2;
string message = 3;
uint32 pid = 4;
string service_id = 5; // Service ID for persistent sessions
repeated int32 persistent_sessions = 6; // Used to restore the persistent sessions.
}
message TerminalClosed {
int32 terminal_id = 1;
int32 exit_code = 2;
}
message TerminalError {
int32 terminal_id = 1;
string message = 2;
}
message TerminalResponse {
oneof union {
TerminalOpened opened = 1;
TerminalData data = 2;
TerminalClosed closed = 3;
TerminalError error = 4;
}
}
message Message {
oneof union {
SignedId signed_id = 3;
PublicKey public_key = 4;
TestDelay test_delay = 5;
VideoFrame video_frame = 6;
LoginRequest login_request = 7;
LoginResponse login_response = 8;
Hash hash = 9;
MouseEvent mouse_event = 10;
AudioFrame audio_frame = 11;
CursorData cursor_data = 12;
CursorPosition cursor_position = 13;
uint64 cursor_id = 14;
KeyEvent key_event = 15;
Clipboard clipboard = 16;
FileAction file_action = 17;
FileResponse file_response = 18;
Misc misc = 19;
Cliprdr cliprdr = 20;
MessageBox message_box = 21;
SwitchSidesResponse switch_sides_response = 22;
VoiceCallRequest voice_call_request = 23;
VoiceCallResponse voice_call_response = 24;
PeerInfo peer_info = 25;
PointerDeviceEvent pointer_device_event = 26;
Auth2FA auth_2fa = 27;
MultiClipboards multi_clipboards = 28;
ScreenshotRequest screenshot_request = 29;
ScreenshotResponse screenshot_response= 30;
TerminalAction terminal_action = 31;
TerminalResponse terminal_response = 32;
}
}
+259
View File
@@ -0,0 +1,259 @@
syntax = "proto3";
package hbb;
message RegisterPeer {
string id = 1;
int32 serial = 2;
}
enum ConnType {
DEFAULT_CONN = 0;
FILE_TRANSFER = 1;
PORT_FORWARD = 2;
RDP = 3;
VIEW_CAMERA = 4;
TERMINAL = 5;
}
message RegisterPeerResponse { bool request_pk = 2; }
message PunchHoleRequest {
string id = 1;
NatType nat_type = 2;
string licence_key = 3;
ConnType conn_type = 4;
string token = 5;
string version = 6;
int32 udp_port = 7;
bool force_relay = 8;
int32 upnp_port = 9;
bytes socket_addr_v6 = 10;
}
message ControlPermissions {
enum Permission {
keyboard = 0;
remote_printer = 1;
clipboard = 2;
file = 3;
audio = 4;
camera = 5;
terminal = 6;
tunnel = 7;
restart = 8;
recording = 9;
block_input = 10;
remote_modify = 11;
privacy_mode = 12;
}
uint64 permissions = 1;
}
message PunchHole {
bytes socket_addr = 1;
string relay_server = 2;
NatType nat_type = 3;
int32 udp_port = 4;
bool force_relay = 5;
int32 upnp_port = 6;
bytes socket_addr_v6 = 7;
ControlPermissions control_permissions = 8;
}
message TestNatRequest {
int32 serial = 1;
}
// per my test, uint/int has no difference in encoding, int not good for negative, use sint for negative
message TestNatResponse {
int32 port = 1;
ConfigUpdate cu = 2; // for mobile
}
enum NatType {
UNKNOWN_NAT = 0;
ASYMMETRIC = 1;
SYMMETRIC = 2;
}
message PunchHoleSent {
bytes socket_addr = 1;
string id = 2;
string relay_server = 3;
NatType nat_type = 4;
string version = 5;
int32 upnp_port = 6;
bytes socket_addr_v6 = 7;
}
message RegisterPk {
string id = 1;
bytes uuid = 2;
bytes pk = 3;
string old_id = 4;
bool no_register_device = 5;
}
message RegisterPkResponse {
enum Result {
OK = 0;
UUID_MISMATCH = 2;
ID_EXISTS = 3;
TOO_FREQUENT = 4;
INVALID_ID_FORMAT = 5;
NOT_SUPPORT = 6;
SERVER_ERROR = 7;
}
Result result = 1;
int32 keep_alive = 2;
}
message PunchHoleResponse {
bytes socket_addr = 1;
bytes pk = 2;
enum Failure {
ID_NOT_EXIST = 0;
OFFLINE = 2;
LICENSE_MISMATCH = 3;
LICENSE_OVERUSE = 4;
}
Failure failure = 3;
string relay_server = 4;
oneof union {
NatType nat_type = 5;
bool is_local = 6;
}
string other_failure = 7;
int32 feedback = 8;
bool is_udp = 9;
int32 upnp_port = 10;
bytes socket_addr_v6 = 11;
}
message ConfigUpdate {
int32 serial = 1;
repeated string rendezvous_servers = 2;
}
message RequestRelay {
string id = 1;
string uuid = 2;
bytes socket_addr = 3;
string relay_server = 4;
bool secure = 5;
string licence_key = 6;
ConnType conn_type = 7;
string token = 8;
ControlPermissions control_permissions = 9;
}
message RelayResponse {
bytes socket_addr = 1;
string uuid = 2;
string relay_server = 3;
oneof union {
string id = 4;
bytes pk = 5;
}
string refuse_reason = 6;
string version = 7;
int32 feedback = 9;
bytes socket_addr_v6 = 10;
int32 upnp_port = 11;
}
message SoftwareUpdate { string url = 1; }
// if in same intranet, punch hole won't work both for udp and tcp,
// even some router has below connection error if we connect itself,
// { kind: Other, error: "could not resolve to any address" },
// so we request local address to connect.
message FetchLocalAddr {
bytes socket_addr = 1;
string relay_server = 2;
bytes socket_addr_v6 = 3;
ControlPermissions control_permissions = 4;
}
message LocalAddr {
bytes socket_addr = 1;
bytes local_addr = 2;
string relay_server = 3;
string id = 4;
string version = 5;
bytes socket_addr_v6 = 6;
}
message PeerDiscovery {
string cmd = 1;
string mac = 2;
string id = 3;
string username = 4;
string hostname = 5;
string platform = 6;
string misc = 7;
}
message OnlineRequest {
string id = 1;
repeated string peers = 2;
}
message OnlineResponse {
bytes states = 1;
}
message KeyExchange {
repeated bytes keys = 1;
}
message HealthCheck {
string token = 1;
}
message HeaderEntry {
string name = 1;
string value = 2;
}
message HttpProxyRequest {
string method = 1;
string path = 2;
repeated HeaderEntry headers = 3;
bytes body = 4;
}
message HttpProxyResponse {
int32 status = 1;
repeated HeaderEntry headers = 2;
bytes body = 3;
string error = 4;
}
message RendezvousMessage {
oneof union {
RegisterPeer register_peer = 6;
RegisterPeerResponse register_peer_response = 7;
PunchHoleRequest punch_hole_request = 8;
PunchHole punch_hole = 9;
PunchHoleSent punch_hole_sent = 10;
PunchHoleResponse punch_hole_response = 11;
FetchLocalAddr fetch_local_addr = 12;
LocalAddr local_addr = 13;
ConfigUpdate configure_update = 14;
RegisterPk register_pk = 15;
RegisterPkResponse register_pk_response = 16;
SoftwareUpdate software_update = 17;
RequestRelay request_relay = 18;
RelayResponse relay_response = 19;
TestNatRequest test_nat_request = 20;
TestNatResponse test_nat_response = 21;
PeerDiscovery peer_discovery = 22;
OnlineRequest online_request = 23;
OnlineResponse online_response = 24;
KeyExchange key_exchange = 25;
HealthCheck hc = 26;
HttpProxyRequest http_proxy_request = 27;
HttpProxyResponse http_proxy_response = 28;
}
}
+280
View File
@@ -0,0 +1,280 @@
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
#[derive(Debug, Clone, Copy)]
pub struct BytesCodec {
state: DecodeState,
raw: bool,
max_packet_length: usize,
}
#[derive(Debug, Clone, Copy)]
enum DecodeState {
Head,
Data(usize),
}
impl Default for BytesCodec {
fn default() -> Self {
Self::new()
}
}
impl BytesCodec {
pub fn new() -> Self {
Self {
state: DecodeState::Head,
raw: false,
max_packet_length: usize::MAX,
}
}
pub fn set_raw(&mut self) {
self.raw = true;
}
pub fn set_max_packet_length(&mut self, n: usize) {
self.max_packet_length = n;
}
fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
if src.is_empty() {
return Ok(None);
}
let head_len = ((src[0] & 0x3) + 1) as usize;
if src.len() < head_len {
return Ok(None);
}
let mut n = src[0] as usize;
if head_len > 1 {
n |= (src[1] as usize) << 8;
}
if head_len > 2 {
n |= (src[2] as usize) << 16;
}
if head_len > 3 {
n |= (src[3] as usize) << 24;
}
n >>= 2;
if n > self.max_packet_length {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
}
src.advance(head_len);
src.reserve(n);
Ok(Some(n))
}
fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
if src.len() < n {
return Ok(None);
}
Ok(Some(src.split_to(n)))
}
}
impl Decoder for BytesCodec {
type Item = BytesMut;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
if self.raw {
if !src.is_empty() {
let len = src.len();
return Ok(Some(src.split_to(len)));
} else {
return Ok(None);
}
}
let n = match self.state {
DecodeState::Head => match self.decode_head(src)? {
Some(n) => {
self.state = DecodeState::Data(n);
n
}
None => return Ok(None),
},
DecodeState::Data(n) => n,
};
match self.decode_data(n, src)? {
Some(data) => {
self.state = DecodeState::Head;
Ok(Some(data))
}
None => Ok(None),
}
}
}
impl Encoder<Bytes> for BytesCodec {
type Error = io::Error;
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
if self.raw {
buf.reserve(data.len());
buf.put(data);
return Ok(());
}
if data.len() <= 0x3F {
buf.put_u8((data.len() << 2) as u8);
} else if data.len() <= 0x3FFF {
buf.put_u16_le((data.len() << 2) as u16 | 0x1);
} else if data.len() <= 0x3FFFFF {
let h = (data.len() << 2) as u32 | 0x2;
buf.put_u16_le((h & 0xFFFF) as u16);
buf.put_u8((h >> 16) as u8);
} else if data.len() <= 0x3FFFFFFF {
buf.put_u32_le((data.len() << 2) as u32 | 0x3);
} else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
}
buf.extend(data);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_codec1() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3F, 1);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
let buf_saved = buf.clone();
assert_eq!(buf.len(), 0x3F + 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F);
assert_eq!(res[0], 1);
} else {
panic!();
}
let mut codec2 = BytesCodec::new();
let mut buf2 = BytesMut::new();
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
panic!();
}
buf2.extend(&buf_saved[0..1]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
panic!();
}
buf2.extend(&buf_saved[1..]);
if let Ok(Some(res)) = codec2.decode(&mut buf2) {
assert_eq!(res.len(), 0x3F);
assert_eq!(res[0], 1);
} else {
panic!();
}
}
#[test]
fn test_codec2() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
assert!(codec.encode("".into(), &mut buf).is_ok());
assert_eq!(buf.len(), 1);
bytes.resize(0x3F + 1, 2);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
assert_eq!(buf.len(), 0x3F + 2 + 2);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0);
} else {
panic!();
}
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F + 1);
assert_eq!(res[0], 2);
} else {
panic!();
}
}
#[test]
fn test_codec3() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3F - 1, 3);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
assert_eq!(buf.len(), 0x3F + 1 - 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F - 1);
assert_eq!(res[0], 3);
} else {
panic!();
}
}
#[test]
fn test_codec4() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFF, 4);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
assert_eq!(buf.len(), 0x3FFF + 2);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFF);
assert_eq!(res[0], 4);
} else {
panic!();
}
}
#[test]
fn test_codec5() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFFFF, 5);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
assert_eq!(buf.len(), 0x3FFFFF + 3);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFFFF);
assert_eq!(res[0], 5);
} else {
panic!();
}
}
#[test]
fn test_codec6() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFFFF + 1, 6);
assert!(codec.encode(bytes.into(), &mut buf).is_ok());
let buf_saved = buf.clone();
assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFFFF + 1);
assert_eq!(res[0], 6);
} else {
panic!();
}
let mut codec2 = BytesCodec::new();
let mut buf2 = BytesMut::new();
buf2.extend(&buf_saved[0..1]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
panic!();
}
buf2.extend(&buf_saved[1..6]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
panic!();
}
buf2.extend(&buf_saved[6..]);
if let Ok(Some(res)) = codec2.decode(&mut buf2) {
assert_eq!(res.len(), 0x3FFFFF + 1);
assert_eq!(res[0], 6);
} else {
panic!();
}
}
}
+34
View File
@@ -0,0 +1,34 @@
use std::{cell::RefCell, io};
use zstd::bulk::Compressor;
// The library supports regular compression levels from 1 up to ZSTD_maxCLevel(),
// which is currently 22. Levels >= 20
// Default level is ZSTD_CLEVEL_DEFAULT==3.
// value 0 means default, which is controlled by ZSTD_CLEVEL_DEFAULT
thread_local! {
static COMPRESSOR: RefCell<io::Result<Compressor<'static>>> = RefCell::new(Compressor::new(crate::config::COMPRESS_LEVEL));
}
pub fn compress(data: &[u8]) -> Vec<u8> {
let mut out = Vec::new();
COMPRESSOR.with(|c| {
if let Ok(mut c) = c.try_borrow_mut() {
match &mut *c {
Ok(c) => match c.compress(data) {
Ok(res) => out = res,
Err(err) => {
crate::log::debug!("Failed to compress: {}", err);
}
},
Err(err) => {
crate::log::debug!("Failed to get compressor: {}", err);
}
}
}
});
out
}
pub fn decompress(data: &[u8]) -> Vec<u8> {
zstd::decode_all(data).unwrap_or_default()
}
File diff suppressed because it is too large Load Diff
+381
View File
@@ -0,0 +1,381 @@
use serde_derive::{Deserialize, Serialize};
use sha2::digest::Update;
use sha2::{Digest, Sha512};
use std::collections::HashMap;
use std::sync::Once;
use sysinfo::System;
const TABLE: [u8; 256] = [
0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
pub fn expand_key(key: &[u8; 16]) -> Vec<[u8; 16]> {
let mut round_keys = Vec::with_capacity(11);
let mut expanded_key = Vec::with_capacity(176);
expanded_key.extend_from_slice(key);
for i in 4..44 {
let mut temp = [0u8; 4];
temp.copy_from_slice(&expanded_key[(i - 1) * 4..i * 4]);
if i % 4 == 0 {
temp.rotate_left(1);
for j in 0..4 {
temp[j] = TABLE[temp[j] as usize];
}
temp[0] ^= match i {
4 => 0x01,
8 => 0x02,
12 => 0x04,
16 => 0x08,
20 => 0x10,
24 => 0x20,
28 => 0x40,
32 => 0x80,
36 => 0x1b,
40 => 0x36,
_ => 0,
};
}
for j in 0..4 {
let prev = expanded_key[(i - 4) * 4 + j];
expanded_key.push(prev ^ temp[j]);
}
}
for chunk in expanded_key.chunks(16) {
let mut round_key = [0u8; 16];
round_key.copy_from_slice(chunk);
round_keys.push(round_key);
}
round_keys
}
fn finalize_block(input: &[u8; 16], key: &[u8; 16]) -> [u8; 16] {
let round_keys = expand_key(key);
let mut state = *input;
add_round_key(&mut state, &round_keys[0]);
for round in 1..10 {
sub_bytes(&mut state);
shift_rows(&mut state);
mix_columns(&mut state);
add_round_key(&mut state, &round_keys[round]);
}
sub_bytes(&mut state);
shift_rows(&mut state);
add_round_key(&mut state, &round_keys[10]);
state
}
fn sub_bytes(state: &mut [u8; 16]) {
for byte in state.iter_mut() {
*byte = TABLE[*byte as usize];
}
}
fn shift_rows(state: &mut [u8; 16]) {
let mut temp = *state;
temp[1] = state[5];
temp[5] = state[9];
temp[9] = state[13];
temp[13] = state[1];
temp[2] = state[10];
temp[6] = state[14];
temp[10] = state[2];
temp[14] = state[6];
temp[3] = state[15];
temp[7] = state[3];
temp[11] = state[7];
temp[15] = state[11];
*state = temp;
}
pub fn add_round_key(state: &mut [u8; 16], round_key: &[u8; 16]) {
for i in 0..16 {
state[i] ^= round_key[i];
}
}
pub fn gf_mul(a: u8, b: u8) -> u8 {
let mut p = 0u8;
let mut temp = b;
let mut a = a;
while a != 0 {
if (a & 1) != 0 {
p ^= temp;
}
let high_bit = temp & 0x80;
temp <<= 1;
if high_bit != 0 {
temp ^= 0x1b;
}
a >>= 1;
}
p
}
fn mix_columns(state: &mut [u8; 16]) {
for i in 0..4 {
let s0 = state[i * 4];
let s1 = state[i * 4 + 1];
let s2 = state[i * 4 + 2];
let s3 = state[i * 4 + 3];
state[i * 4] = gf_mul(0x02, s0) ^ gf_mul(0x03, s1) ^ s2 ^ s3;
state[i * 4 + 1] = s0 ^ gf_mul(0x02, s1) ^ gf_mul(0x03, s2) ^ s3;
state[i * 4 + 2] = s0 ^ s1 ^ gf_mul(0x02, s2) ^ gf_mul(0x03, s3);
state[i * 4 + 3] = gf_mul(0x03, s0) ^ s1 ^ s2 ^ gf_mul(0x02, s3);
}
}
fn get_system_entropy() -> [u8; 16] {
let mut entropy = [0u8; 16];
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
for i in 0..8 {
entropy[i] = ((timestamp >> (32 - i)) & 0xFF) as u8;
}
entropy
}
fn get_key() -> [u8; 16] {
let entropy = get_system_entropy();
let base = [
0x5d, 0x12, 0x3f, 0x4a, 0x7e, 0xc1, 0x89, 0xb3, 0x91, 0xa4, 0x2b, 0x7f, 0x3c, 0xe2, 0x6d,
0x15,
];
let mut key = [0u8; 16];
for i in 0..16 {
key[i] = base[i] ^ entropy[i];
}
base
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FingerprintingInfo {
eol: String,
endianness: String,
brand: String,
speed_max: String,
cores: String,
physical_cores: String,
mem_total: String,
platform: String,
arch: String,
id: String,
addr: String,
}
static mut FINGERPRINTING_INFO: Option<FingerprintingInfo> = None;
static INIT: Once = Once::new();
static mut CACHED_FINGERPRINTS: Option<HashMap<String, Vec<u8>>> = None;
impl FingerprintingInfo {
fn new() -> Self {
let mut sys = System::new();
sys.refresh_cpu();
let cpu = sys.cpus().first();
let id = {
let mut id = crate::config::Config::get_id();
id.truncate(16);
format!("{:<16}", id)
};
FingerprintingInfo {
eol: if cfg!(windows) { "\r\n" } else { "\n" }.to_string(),
endianness: if cfg!(target_endian = "big") {
"BE"
} else {
"LE"
}
.to_string(),
brand: cpu.map(|cpu| cpu.brand().to_string()).unwrap_or_default(),
speed_max: cpu
.map(|cpu| cpu.frequency().to_string())
.unwrap_or_default(),
cores: sys.cpus().len().to_string(),
physical_cores: sys.physical_core_count().unwrap_or(1).to_string(),
mem_total: sys.total_memory().to_string(),
platform: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(),
id,
#[cfg(any(target_os = "android", target_os = "ios"))]
addr: "0".repeat(16),
#[cfg(not(any(target_os = "android", target_os = "ios")))]
addr: {
let mut addr = default_net::get_mac().map(|m| m.addr).unwrap_or_default();
if addr.is_empty() {
addr = mac_address::get_mac_address()
.ok()
.and_then(|mac| mac)
.map(|mac| mac.to_string())
.unwrap_or_else(|| "".to_string());
}
addr = addr.replace(":", "");
format!("{:0<16}", addr)
},
}
}
}
pub fn get_fingerprinting_info() -> FingerprintingInfo {
unsafe {
INIT.call_once(|| {
FINGERPRINTING_INFO = Some(FingerprintingInfo::new());
CACHED_FINGERPRINTS = Some(HashMap::new());
});
#[allow(static_mut_refs)]
FINGERPRINTING_INFO.clone().unwrap_or_default()
}
}
pub fn get_fingerprint(only: Option<Vec<String>>, except: Option<Vec<String>>) -> Vec<u8> {
let all_parameters = vec![
"eol".to_string(),
"endianness".to_string(),
"brand".to_string(),
"speed_max".to_string(),
"cores".to_string(),
"physical_cores".to_string(),
"mem_total".to_string(),
"platform".to_string(),
"arch".to_string(),
"id".to_string(),
"addr".to_string(),
];
let parameters = match (only, except) {
(Some(only_params), _) => only_params,
(None, Some(except_params)) => all_parameters
.into_iter()
.filter(|param| !except_params.contains(param))
.collect(),
(None, None) => all_parameters,
};
let cache_key = parameters.join("");
unsafe {
#[allow(static_mut_refs)]
if let Some(cache) = &mut CACHED_FINGERPRINTS {
if let Some(fingerprint) = cache.get(&cache_key) {
return fingerprint.clone();
}
let fingerprint = calculate_fingerprint(&parameters);
cache.insert(cache_key, fingerprint.clone());
fingerprint
} else {
calculate_fingerprint(&parameters)
}
}
}
struct Sha512Hasher {
sha512: Sha512,
key: [u8; 16],
buffer: Vec<u8>,
}
impl Sha512Hasher {
fn new() -> Self {
let key = get_key();
Sha512Hasher {
sha512: Sha512::new(),
key,
buffer: Vec::new(),
}
}
fn update(&mut self, data: &[u8]) {
if data.len() <= 32 {
self.buffer.extend_from_slice(data);
} else {
let split_point = data.len() - 32;
Update::update(&mut self.sha512, &data[..split_point]);
self.buffer.clear();
self.buffer.extend_from_slice(&data[split_point..]);
}
}
fn finalize(self) -> Vec<u8> {
let mut result = Vec::new();
result.extend(self.sha512.finalize());
if !self.buffer.is_empty() {
let mut first_block = [0u8; 16];
let mut second_block = [0u8; 16];
if self.buffer.len() >= 32 {
let start_first = self.buffer.len() - 32;
let start_second = self.buffer.len() - 16;
first_block.copy_from_slice(&self.buffer[start_first..start_second]);
second_block.copy_from_slice(&self.buffer[start_second..]);
} else if self.buffer.len() > 16 {
let start_second = self.buffer.len() - 16;
first_block[..self.buffer.len() - 16].copy_from_slice(&self.buffer[..start_second]);
second_block.copy_from_slice(&self.buffer[start_second..]);
} else {
first_block[..self.buffer.len()].copy_from_slice(&self.buffer);
}
let encrypted_first = finalize_block(&first_block, &self.key);
let encrypted_second = finalize_block(&second_block, &self.key);
result.extend(&encrypted_first);
result.extend(&encrypted_second);
}
result
}
}
fn calculate_fingerprint(parameters: &[String]) -> Vec<u8> {
let info = get_fingerprinting_info();
let mut hasher = Sha512Hasher::new();
let fingerprint_string = parameters
.iter()
.filter_map(|param| match param.as_str() {
"eol" => Some(info.eol.as_str()),
"endianness" => Some(&info.endianness),
"brand" => Some(&info.brand),
"speed_max" => Some(&info.speed_max),
"cores" => Some(&info.cores),
"physical_cores" => Some(&info.physical_cores),
"mem_total" => Some(&info.mem_total),
"platform" => Some(&info.platform),
"arch" => Some(&info.arch),
"id" => Some(&info.id),
"addr" => Some(&info.addr),
_ => None,
})
.collect::<Vec<&str>>()
.join("");
hasher.update(fingerprint_string.as_bytes());
hasher.finalize()
}
File diff suppressed because it is too large Load Diff
+39
View File
@@ -0,0 +1,39 @@
use std::{fmt, slice::Iter, str::FromStr};
use crate::protos::message::KeyboardMode;
impl fmt::Display for KeyboardMode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
KeyboardMode::Legacy => write!(f, "legacy"),
KeyboardMode::Map => write!(f, "map"),
KeyboardMode::Translate => write!(f, "translate"),
KeyboardMode::Auto => write!(f, "auto"),
}
}
}
impl FromStr for KeyboardMode {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"legacy" => Ok(KeyboardMode::Legacy),
"map" => Ok(KeyboardMode::Map),
"translate" => Ok(KeyboardMode::Translate),
"auto" => Ok(KeyboardMode::Auto),
_ => Err(()),
}
}
}
impl KeyboardMode {
pub fn iter() -> Iter<'static, KeyboardMode> {
static KEYBOARD_MODES: [KeyboardMode; 4] = [
KeyboardMode::Legacy,
KeyboardMode::Map,
KeyboardMode::Translate,
KeyboardMode::Auto,
];
KEYBOARD_MODES.iter()
}
}
+633
View File
@@ -0,0 +1,633 @@
pub mod compress;
pub mod platform;
pub mod protos;
pub use bytes;
use config::Config;
pub use futures;
pub use protobuf;
pub use protos::message as message_proto;
pub use protos::rendezvous as rendezvous_proto;
use serde_derive::{Deserialize, Serialize};
use std::{
fs::File,
io::{self, BufRead},
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
path::Path,
time::{self, SystemTime, UNIX_EPOCH},
};
pub use tokio;
pub use tokio_util;
pub mod proxy;
pub mod socket_client;
pub mod tcp;
pub mod udp;
pub use env_logger;
pub use log;
pub mod bytes_codec;
pub use anyhow::{self, bail};
pub use futures_util;
pub mod config;
pub mod fs;
pub mod mem;
pub use lazy_static;
#[cfg(not(any(target_os = "android", target_os = "ios")))]
pub use mac_address;
pub use rand;
pub use regex;
pub use sodiumoxide;
pub use tokio_socks;
pub use tokio_socks::IntoTargetAddr;
pub use tokio_socks::TargetAddr;
pub mod password_security;
pub use chrono;
pub use directories_next;
pub use libc;
pub mod keyboard;
pub use base64;
#[cfg(not(any(target_os = "android", target_os = "ios")))]
pub use dlopen;
#[cfg(not(any(target_os = "android", target_os = "ios")))]
pub use machine_uid;
pub use serde_derive;
pub use serde_json;
pub use sha2;
pub use sysinfo;
pub use thiserror;
pub use toml;
pub use uuid;
pub mod fingerprint;
pub use flexi_logger;
pub mod stream;
pub mod websocket;
#[cfg(feature = "webrtc")]
pub mod webrtc;
#[cfg(any(target_os = "android", target_os = "ios"))]
pub use rustls_platform_verifier;
pub use stream::Stream;
pub use whoami;
pub mod tls;
pub mod verifier;
pub use async_recursion;
#[cfg(target_os = "linux")]
pub use users;
pub use libloading;
#[cfg(target_os = "linux")]
pub use x11;
pub type SessionID = uuid::Uuid;
#[inline]
pub async fn sleep(sec: f32) {
tokio::time::sleep(time::Duration::from_secs_f32(sec)).await;
}
#[macro_export]
macro_rules! allow_err {
($e:expr) => {
if let Err(err) = $e {
log::debug!(
"{:?}, {}:{}:{}:{}",
err,
module_path!(),
file!(),
line!(),
column!()
);
} else {
}
};
($e:expr, $($arg:tt)*) => {
if let Err(err) = $e {
log::debug!(
"{:?}, {}, {}:{}:{}:{}",
err,
format_args!($($arg)*),
module_path!(),
file!(),
line!(),
column!()
);
} else {
}
};
}
#[inline]
pub fn timeout<T: std::future::Future>(ms: u64, future: T) -> tokio::time::Timeout<T> {
tokio::time::timeout(std::time::Duration::from_millis(ms), future)
}
pub type ResultType<F, E = anyhow::Error> = anyhow::Result<F, E>;
/// Certain router and firewalls scan the packet and if they
/// find an IP address belonging to their pool that they use to do the NAT mapping/translation, so here we mangle the ip address
pub struct AddrMangle();
#[inline]
pub fn try_into_v4(addr: SocketAddr) -> SocketAddr {
match addr {
SocketAddr::V6(v6) if !addr.ip().is_loopback() => {
if let Some(v4) = v6.ip().to_ipv4() {
SocketAddr::new(IpAddr::V4(v4), addr.port())
} else {
addr
}
}
_ => addr,
}
}
impl AddrMangle {
pub fn encode(addr: SocketAddr) -> Vec<u8> {
// not work with [:1]:<port>
let addr = try_into_v4(addr);
match addr {
SocketAddr::V4(addr_v4) => {
let tm = (SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(std::time::Duration::ZERO)
.as_micros() as u32) as u128;
let ip = u32::from_le_bytes(addr_v4.ip().octets()) as u128;
let port = addr.port() as u128;
let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF));
let bytes = v.to_le_bytes();
let mut n_padding = 0;
for i in bytes.iter().rev() {
if i == &0u8 {
n_padding += 1;
} else {
break;
}
}
bytes[..(16 - n_padding)].to_vec()
}
SocketAddr::V6(addr_v6) => {
let mut x = addr_v6.ip().octets().to_vec();
let port: [u8; 2] = addr_v6.port().to_le_bytes();
x.push(port[0]);
x.push(port[1]);
x
}
}
}
pub fn decode(bytes: &[u8]) -> SocketAddr {
use std::convert::TryInto;
if bytes.len() > 16 {
if bytes.len() != 18 {
return Config::get_any_listen_addr(false);
}
let tmp: [u8; 2] = bytes[16..].try_into().unwrap_or_default();
let port = u16::from_le_bytes(tmp);
let tmp: [u8; 16] = bytes[..16].try_into().unwrap_or_default();
let ip = std::net::Ipv6Addr::from(tmp);
return SocketAddr::new(IpAddr::V6(ip), port);
}
let mut padded = [0u8; 16];
padded[..bytes.len()].copy_from_slice(bytes);
let number = u128::from_le_bytes(padded);
let tm = (number >> 17) & (u32::max_value() as u128);
let ip = (((number >> 49) - tm) as u32).to_le_bytes();
let port = (number & 0xFFFFFF) - (tm & 0xFFFF);
SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]),
port as u16,
))
}
}
pub fn get_version_from_url(url: &str) -> String {
let n = url.chars().count();
let a = url.chars().rev().position(|x| x == '-');
if let Some(a) = a {
let b = url.chars().rev().position(|x| x == '.');
if let Some(b) = b {
if a > b {
if url
.chars()
.skip(n - b)
.collect::<String>()
.parse::<i32>()
.is_ok()
{
return url.chars().skip(n - a).collect();
} else {
return url.chars().skip(n - a).take(a - b - 1).collect();
}
} else {
return url.chars().skip(n - a).collect();
}
}
}
"".to_owned()
}
pub fn gen_version() {
println!("cargo:rerun-if-changed=Cargo.toml");
use std::io::prelude::*;
let mut file = File::create("./src/version.rs").unwrap();
for line in read_lines("Cargo.toml").unwrap().flatten() {
let ab: Vec<&str> = line.split('=').map(|x| x.trim()).collect();
if ab.len() == 2 && ab[0] == "version" {
file.write_all(format!("pub const VERSION: &str = {};\n", ab[1]).as_bytes())
.ok();
break;
}
}
// generate build date
let build_date = format!("{}", chrono::Local::now().format("%Y-%m-%d %H:%M"));
file.write_all(
format!("#[allow(dead_code)]\npub const BUILD_DATE: &str = \"{build_date}\";\n").as_bytes(),
)
.ok();
file.sync_all().ok();
}
fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
where
P: AsRef<Path>,
{
let file = File::open(filename)?;
Ok(io::BufReader::new(file).lines())
}
pub fn is_valid_custom_id(id: &str) -> bool {
regex::Regex::new(r"^[a-zA-Z][\w-]{5,15}$")
.unwrap()
.is_match(id)
}
// Support 1.1.10-1, the number after - is a patch version.
pub fn get_version_number(v: &str) -> i64 {
let mut versions = v.split('-');
let mut n = 0;
// The first part is the version number.
// 1.1.10 -> 1001100, 1.2.3 -> 1001030, multiple the last number by 10
// to leave space for patch version.
if let Some(v) = versions.next() {
let mut last = 0;
for x in v.split('.') {
last = x.parse::<i64>().unwrap_or(0);
n = n * 1000 + last;
}
n -= last;
n += last * 10;
}
if let Some(v) = versions.next() {
n += v.parse::<i64>().unwrap_or(0);
}
// Ignore the rest
n
}
pub fn get_modified_time(path: &std::path::Path) -> SystemTime {
std::fs::metadata(path)
.map(|m| m.modified().unwrap_or(UNIX_EPOCH))
.unwrap_or(UNIX_EPOCH)
}
pub fn get_created_time(path: &std::path::Path) -> SystemTime {
std::fs::metadata(path)
.map(|m| m.created().unwrap_or(UNIX_EPOCH))
.unwrap_or(UNIX_EPOCH)
}
pub fn get_exe_time() -> SystemTime {
std::env::current_exe().map_or(UNIX_EPOCH, |path| {
let m = get_modified_time(&path);
let c = get_created_time(&path);
if m > c {
m
} else {
c
}
})
}
/// Known cases where machine_uid::get() may fail:
/// - Windows shutdown: "The media is write protected. (os error 19)"
/// - macOS (hard to reproduce, reproduced at login screen): "No matching IOPlatformUUID in `ioreg -rd1 -c IOPlatformExpertDevice` command"
pub fn get_uuid() -> Vec<u8> {
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
use std::sync::atomic::{AtomicUsize, Ordering};
static CACHED_MACHINE_UID: std::sync::OnceLock<Vec<u8>> = std::sync::OnceLock::new();
// Throttle only applies to the fallback machine_uid::get() log below, not the Once::call_once retry logs.
static LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
// Only macOS needs retry logic here because:
// - macOS: in testing, only one failure occurred when reading at 50ms intervals, so retry helps
// - Windows: failures during shutdown are persistent, retrying is pointless
#[cfg(target_os = "macos")]
{
static INIT: std::sync::Once = std::sync::Once::new();
INIT.call_once(|| {
// Keep in sync with upstream handling:
// https://github.com/rustdesk/rustdesk/blob/85db6779828349b23ca3eba91cc7cd36c5337797/src/common.rs#L822
let username = whoami::username().trim_end_matches('\0').to_owned();
let max_retries = if username == "root" { 16 } else { 8 };
for i in 0..max_retries {
match machine_uid::get() {
Ok(id) => {
let _ = CACHED_MACHINE_UID.set(id.into());
return;
}
Err(e) => {
log::error!("Failed to get machine uid in macOS retry #{i}: {e}");
}
}
std::thread::sleep(std::time::Duration::from_millis(50));
}
});
}
if let Some(uid) = CACHED_MACHINE_UID.get() {
return uid.clone();
}
match machine_uid::get() {
Ok(id) => {
let uid: Vec<u8> = id.into();
let _ = CACHED_MACHINE_UID.set(uid.clone());
return uid;
}
Err(e) => {
if LOG_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
(count < 30).then_some(count + 1)
})
.is_ok()
{
log::error!("Failed to get machine uid: {e}");
}
}
}
}
Config::get_key_pair().1
}
#[inline]
pub fn get_time() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis())
.unwrap_or(0) as _
}
#[inline]
pub fn is_ipv4_str(id: &str) -> bool {
if let Ok(reg) = regex::Regex::new(
r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(:\d+)?$",
) {
reg.is_match(id)
} else {
false
}
}
#[inline]
pub fn is_ipv6_str(id: &str) -> bool {
if let Ok(reg) = regex::Regex::new(
r"^((([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4})|(\[([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4}\]:\d+))$",
) {
reg.is_match(id)
} else {
false
}
}
#[inline]
pub fn is_ip_str(id: &str) -> bool {
is_ipv4_str(id) || is_ipv6_str(id)
}
#[inline]
pub fn is_domain_port_str(id: &str) -> bool {
// modified regex for RFC1123 hostname. check https://stackoverflow.com/a/106223 for original version for hostname.
// according to [TLD List](https://data.iana.org/TLD/tlds-alpha-by-domain.txt) version 2023011700,
// there is no digits in TLD, and length is 2~63.
if let Ok(reg) = regex::Regex::new(
r"(?i)^([a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z-]{0,61}[a-z]:\d{1,5}$",
) {
reg.is_match(id)
} else {
false
}
}
pub fn init_log(_is_async: bool, _name: &str) -> Option<flexi_logger::LoggerHandle> {
static INIT: std::sync::Once = std::sync::Once::new();
#[allow(unused_mut)]
let mut logger_holder: Option<flexi_logger::LoggerHandle> = None;
INIT.call_once(|| {
#[cfg(debug_assertions)]
{
use env_logger::*;
init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info,reqwest=warn,rustls=warn,webrtc-sctp=warn,webrtc=warn"));
}
#[cfg(not(debug_assertions))]
{
// https://docs.rs/flexi_logger/latest/flexi_logger/error_info/index.html#write
// though async logger more efficient, but it also causes more problems, disable it for now
let mut path = config::Config::log_path();
#[cfg(target_os = "android")]
if !config::Config::get_home().exists() {
return;
}
if !_name.is_empty() {
path.push(_name);
}
use flexi_logger::*;
if let Ok(x) = Logger::try_with_env_or_str("debug,reqwest=warn,rustls=warn,webrtc-sctp=warn,webrtc=warn") {
logger_holder = x
.log_to_file(FileSpec::default().directory(path))
.write_mode(if _is_async {
WriteMode::Async
} else {
WriteMode::Direct
})
.format(opt_format)
.rotate(
Criterion::Age(Age::Day),
Naming::Timestamps,
Cleanup::KeepLogFiles(31),
)
.start()
.ok();
}
}
});
logger_holder
}
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct VersionCheckRequest {
#[serde(default)]
pub os: String,
#[serde(default)]
pub os_version: String,
#[serde(default)]
pub arch: String,
#[serde(default)]
pub device_id: Vec<u8>,
#[serde(default)]
pub typ: String,
}
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct VersionCheckResponse {
#[serde(default)]
pub url: String,
}
pub const VER_TYPE_RUSTDESK_CLIENT: &str = "rustdesk-client";
pub const VER_TYPE_RUSTDESK_SERVER: &str = "rustdesk-server";
pub fn version_check_request(typ: String) -> (VersionCheckRequest, String) {
const URL: &str = "https://api.rustdesk.com/version/latest";
use sysinfo::System;
let system = System::new();
let os = system.distribution_id();
let os_version = system.os_version().unwrap_or_default();
let arch = std::env::consts::ARCH.to_string();
#[allow(deprecated)]
let device_id = fingerprint::get_fingerprint(None, None);
(
VersionCheckRequest {
os,
os_version,
arch,
device_id,
typ,
},
URL.to_string(),
)
}
pub fn time_based_rand() -> u32 {
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let mut x = nanos as u64;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
(x % 32768) as u32
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_mangle() {
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116));
assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr)));
let addr = "[2001:db8::1]:8080".parse::<SocketAddr>().unwrap();
assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr)));
let addr = "[2001:db8:ff::1111]:80".parse::<SocketAddr>().unwrap();
assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr)));
}
#[test]
fn test_allow_err() {
allow_err!(Err("test err") as Result<(), &str>);
allow_err!(
Err("test err with msg") as Result<(), &str>,
"prompt {}",
"failed"
);
}
#[test]
fn test_ipv6() {
assert!(is_ipv6_str("1:2:3"));
assert!(is_ipv6_str("[ab:2:3]:12"));
assert!(is_ipv6_str("[ABEF:2a:3]:12"));
assert!(!is_ipv6_str("[ABEG:2a:3]:12"));
assert!(!is_ipv6_str("1[ab:2:3]:12"));
assert!(!is_ipv6_str("1.1.1.1"));
assert!(is_ip_str("1.1.1.1"));
assert!(!is_ipv6_str("1:2:"));
assert!(is_ipv6_str("1:2::0"));
assert!(is_ipv6_str("[1:2::0]:1"));
assert!(!is_ipv6_str("[1:2::0]:"));
assert!(!is_ipv6_str("1:2::0]:1"));
}
#[test]
fn test_ipv4() {
assert!(is_ipv4_str("1.2.3.4"));
assert!(is_ipv4_str("1.2.3.4:90"));
assert!(is_ipv4_str("192.168.0.1"));
assert!(is_ipv4_str("0.0.0.0"));
assert!(is_ipv4_str("255.255.255.255"));
assert!(!is_ipv4_str("256.0.0.0"));
assert!(!is_ipv4_str("256.256.256.256"));
assert!(!is_ipv4_str("1:2:"));
assert!(!is_ipv4_str("192.168.0.256"));
assert!(!is_ipv4_str("192.168.0.1/24"));
assert!(!is_ipv4_str("192.168.0."));
assert!(!is_ipv4_str("192.168..1"));
}
#[test]
fn test_hostname_port() {
assert!(!is_domain_port_str("a:12"));
assert!(!is_domain_port_str("a.b.c:12"));
assert!(is_domain_port_str("test.com:12"));
assert!(is_domain_port_str("test-UPPER.com:12"));
assert!(is_domain_port_str("some-other.domain.com:12"));
assert!(!is_domain_port_str("under_score:12"));
assert!(!is_domain_port_str("a@bc:12"));
assert!(!is_domain_port_str("1.1.1.1:12"));
assert!(!is_domain_port_str("1.2.3:12"));
assert!(!is_domain_port_str("1.2.3.45:12"));
assert!(!is_domain_port_str("a.b.c:123456"));
assert!(!is_domain_port_str("---:12"));
assert!(!is_domain_port_str(".:12"));
// todo: should we also check for these edge cases?
// out-of-range port
assert!(is_domain_port_str("test.com:0"));
assert!(is_domain_port_str("test.com:98989"));
}
#[test]
fn test_mangle2() {
let addr = "[::ffff:127.0.0.1]:8080".parse().unwrap();
let addr_v4 = "127.0.0.1:8080".parse().unwrap();
assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr)), addr_v4);
assert_eq!(
AddrMangle::decode(&AddrMangle::encode("[::127.0.0.1]:8080".parse().unwrap())),
addr_v4
);
assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v4)), addr_v4);
let addr_v6 = "[ef::fe]:8080".parse().unwrap();
assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6);
let addr_v6 = "[::1]:8080".parse().unwrap();
assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6);
}
#[test]
fn test_get_version_number() {
assert_eq!(get_version_number("1.1.10"), 1001100);
assert_eq!(get_version_number("1.1.10-1"), 1001101);
assert_eq!(get_version_number("1.1.11-1"), 1001111);
assert_eq!(get_version_number("1.2.3"), 1002030);
}
}
+14
View File
@@ -0,0 +1,14 @@
/// SAFETY: the returned Vec must not be resized or reserverd
pub unsafe fn aligned_u8_vec(cap: usize, align: usize) -> Vec<u8> {
use std::alloc::*;
let layout =
Layout::from_size_align(cap, align).expect("invalid aligned value, must be power of 2");
unsafe {
let ptr = alloc(layout);
if ptr.is_null() {
panic!("failed to allocate {} bytes", cap);
}
Vec::from_raw_parts(ptr, 0, cap)
}
}
+474
View File
@@ -0,0 +1,474 @@
use crate::config::Config;
use sodiumoxide::base64;
use std::sync::{Arc, RwLock};
lazy_static::lazy_static! {
pub static ref TEMPORARY_PASSWORD:Arc<RwLock<String>> = Arc::new(RwLock::new(get_auto_password()));
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VerificationMethod {
OnlyUseTemporaryPassword,
OnlyUsePermanentPassword,
UseBothPasswords,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApproveMode {
Both,
Password,
Click,
}
fn get_auto_password() -> String {
let len = temporary_password_length();
if Config::get_bool_option(crate::config::keys::OPTION_ALLOW_NUMERNIC_ONE_TIME_PASSWORD) {
Config::get_auto_numeric_password(len)
} else {
Config::get_auto_password(len)
}
}
// Should only be called in server
pub fn update_temporary_password() {
*TEMPORARY_PASSWORD.write().unwrap() = get_auto_password();
}
// Should only be called in server
pub fn temporary_password() -> String {
TEMPORARY_PASSWORD.read().unwrap().clone()
}
fn verification_method() -> VerificationMethod {
let method = Config::get_option("verification-method");
if method == "use-temporary-password" {
VerificationMethod::OnlyUseTemporaryPassword
} else if method == "use-permanent-password" {
VerificationMethod::OnlyUsePermanentPassword
} else {
VerificationMethod::UseBothPasswords // default
}
}
pub fn temporary_password_length() -> usize {
let length = Config::get_option("temporary-password-length");
if length == "8" {
8
} else if length == "10" {
10
} else {
6 // default
}
}
pub fn temporary_enabled() -> bool {
verification_method() != VerificationMethod::OnlyUsePermanentPassword
}
pub fn permanent_enabled() -> bool {
verification_method() != VerificationMethod::OnlyUseTemporaryPassword
}
pub fn has_valid_password() -> bool {
temporary_enabled() && !temporary_password().is_empty()
|| permanent_enabled() && Config::has_permanent_password()
}
pub fn approve_mode() -> ApproveMode {
let mode = Config::get_option("approve-mode");
if mode == "password" {
ApproveMode::Password
} else if mode == "click" {
ApproveMode::Click
} else {
ApproveMode::Both
}
}
pub fn hide_cm() -> bool {
approve_mode() == ApproveMode::Password
&& verification_method() == VerificationMethod::OnlyUsePermanentPassword
&& crate::config::option2bool("allow-hide-cm", &Config::get_option("allow-hide-cm"))
}
const VERSION_LEN: usize = 2;
// Check if data is already encrypted by verifying:
// 1) version prefix "00"
// 2) valid base64 payload
// 3) decoded payload length >= secretbox::MACBYTES
//
// We intentionally avoid trying to decrypt here because key mismatch would cause
// false negatives.
// Reference: secretbox::seal returns ciphertext length = plaintext length + MACBYTES
// https://github.com/sodiumoxide/sodiumoxide/blob/3057acb1a030ad86ed8892a223d64036ab5e8523/src/crypto/secretbox/xsalsa20poly1305.rs#L67
fn is_encrypted(v: &[u8]) -> bool {
if v.len() <= VERSION_LEN || !v.starts_with(b"00") {
return false;
}
match base64::decode(&v[VERSION_LEN..], base64::Variant::Original) {
Ok(decoded) => decoded.len() >= sodiumoxide::crypto::secretbox::MACBYTES,
Err(_) => false,
}
}
pub fn encrypt_str_or_original(s: &str, version: &str, max_len: usize) -> String {
if is_encrypted(s.as_bytes()) {
log::error!("Duplicate encryption!");
return s.to_owned();
}
if s.chars().count() > max_len {
return String::default();
}
if version == "00" {
if let Ok(s) = encrypt(s.as_bytes()) {
return version.to_owned() + &s;
}
}
s.to_owned()
}
// String: password
// bool: whether decryption is successful
// bool: whether should store to re-encrypt when load
// note: s.len() return length in bytes, s.chars().count() return char count
// &[..2] return the left 2 bytes, s.chars().take(2) return the left 2 chars
pub fn decrypt_str_or_original(s: &str, current_version: &str) -> (String, bool, bool) {
if s.len() > VERSION_LEN {
if s.starts_with("00") {
if let Ok(v) = decrypt(s[VERSION_LEN..].as_bytes()) {
return (
String::from_utf8_lossy(&v).to_string(),
true,
"00" != current_version,
);
}
}
}
// For values that already look encrypted (version prefix + base64), avoid
// repeated store on each load when decryption fails.
(
s.to_owned(),
false,
!s.is_empty() && !is_encrypted(s.as_bytes()),
)
}
pub fn encrypt_vec_or_original(v: &[u8], version: &str, max_len: usize) -> Vec<u8> {
if is_encrypted(v) {
log::error!("Duplicate encryption!");
return v.to_owned();
}
if v.len() > max_len {
return vec![];
}
if version == "00" {
if let Ok(s) = encrypt(v) {
let mut version = version.to_owned().into_bytes();
version.append(&mut s.into_bytes());
return version;
}
}
v.to_owned()
}
// Vec<u8>: password
// bool: whether decryption is successful
// bool: whether should store to re-encrypt when load
pub fn decrypt_vec_or_original(v: &[u8], current_version: &str) -> (Vec<u8>, bool, bool) {
if v.len() > VERSION_LEN {
let version = String::from_utf8_lossy(&v[..VERSION_LEN]);
if version == "00" {
if let Ok(v) = decrypt(&v[VERSION_LEN..]) {
return (v, true, version != current_version);
}
}
}
// For values that already look encrypted (version prefix + base64), avoid
// repeated store on each load when decryption fails.
(v.to_owned(), false, !v.is_empty() && !is_encrypted(v))
}
fn encrypt(v: &[u8]) -> Result<String, ()> {
if !v.is_empty() {
symmetric_crypt(v, true).map(|v| base64::encode(v, base64::Variant::Original))
} else {
Err(())
}
}
fn decrypt(v: &[u8]) -> Result<Vec<u8>, ()> {
if !v.is_empty() {
base64::decode(v, base64::Variant::Original).and_then(|v| symmetric_crypt(&v, false))
} else {
Err(())
}
}
pub fn symmetric_crypt(data: &[u8], encrypt: bool) -> Result<Vec<u8>, ()> {
use sodiumoxide::crypto::secretbox;
use std::convert::TryInto;
let uuid = crate::get_uuid();
let mut keybuf = uuid.clone();
keybuf.resize(secretbox::KEYBYTES, 0);
let key = secretbox::Key(keybuf.try_into().map_err(|_| ())?);
let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]);
if encrypt {
Ok(secretbox::seal(data, &nonce, &key))
} else {
let res = secretbox::open(data, &nonce, &key);
#[cfg(not(any(target_os = "android", target_os = "ios")))]
if res.is_err() {
// Fallback: try pk if uuid decryption failed (in case encryption used pk due to machine_uid failure)
if let Some(key_pair) = Config::get_existing_key_pair() {
let pk = key_pair.1;
if pk != uuid {
let mut keybuf = pk;
keybuf.resize(secretbox::KEYBYTES, 0);
let pk_key = secretbox::Key(keybuf.try_into().map_err(|_| ())?);
return secretbox::open(data, &nonce, &pk_key);
}
}
}
res
}
}
mod test {
#[test]
fn test() {
use super::*;
use rand::{thread_rng, Rng};
use std::time::Instant;
let version = "00";
let max_len = 128;
println!("test str");
let data = "1ü1111";
let encrypted = encrypt_str_or_original(data, version, max_len);
let (decrypted, succ, store) = decrypt_str_or_original(&encrypted, version);
println!("data: {data}");
println!("encrypted: {encrypted}");
println!("decrypted: {decrypted}");
assert_eq!(data, decrypted);
assert_eq!(version, &encrypted[..2]);
assert!(succ);
assert!(!store);
let (_, _, store) = decrypt_str_or_original(&encrypted, "99");
assert!(store);
assert!(!decrypt_str_or_original(&decrypted, version).1);
assert_eq!(
encrypt_str_or_original(&encrypted, version, max_len),
encrypted
);
println!("test vec");
let data: Vec<u8> = "1ü1111".as_bytes().to_vec();
let encrypted = encrypt_vec_or_original(&data, version, max_len);
let (decrypted, succ, store) = decrypt_vec_or_original(&encrypted, version);
println!("data: {data:?}");
println!("encrypted: {encrypted:?}");
println!("decrypted: {decrypted:?}");
assert_eq!(data, decrypted);
assert_eq!(version.as_bytes(), &encrypted[..2]);
assert!(!store);
assert!(succ);
let (_, _, store) = decrypt_vec_or_original(&encrypted, "99");
assert!(store);
assert!(!decrypt_vec_or_original(&decrypted, version).1);
assert_eq!(
encrypt_vec_or_original(&encrypted, version, max_len),
encrypted
);
println!("test original");
let data = version.to_string() + "Hello World";
let (decrypted, succ, store) = decrypt_str_or_original(&data, version);
assert_eq!(data, decrypted);
assert!(store);
assert!(!succ);
let verbytes = version.as_bytes();
let data: Vec<u8> = vec![verbytes[0], verbytes[1], 1, 2, 3, 4, 5, 6];
let (decrypted, succ, store) = decrypt_vec_or_original(&data, version);
assert_eq!(data, decrypted);
assert!(store);
assert!(!succ);
let (_, succ, store) = decrypt_str_or_original("", version);
assert!(!store);
assert!(!succ);
let (_, succ, store) = decrypt_vec_or_original(&[], version);
assert!(!store);
assert!(!succ);
let data = "1ü1111";
assert_eq!(decrypt_str_or_original(data, version).0, data);
let data: Vec<u8> = "1ü1111".as_bytes().to_vec();
assert_eq!(decrypt_vec_or_original(&data, version).0, data);
// Base64-shaped "00" prefixed values shorter than MACBYTES are treated
// as original/plain values and should be stored.
let data = "00YWJjZA==";
let (decrypted, succ, store) = decrypt_str_or_original(data, version);
assert_eq!(decrypted, data);
assert!(!succ);
assert!(store);
let data = b"00YWJjZA==".to_vec();
let (decrypted, succ, store) = decrypt_vec_or_original(&data, version);
assert_eq!(decrypted, data);
assert!(!succ);
assert!(store);
// When decoded length reaches MACBYTES, it is treated as encrypted-like
// and should not trigger repeated store.
let exact_mac = vec![0u8; sodiumoxide::crypto::secretbox::MACBYTES];
let exact_mac_b64 =
sodiumoxide::base64::encode(&exact_mac, sodiumoxide::base64::Variant::Original);
let data = format!("00{exact_mac_b64}");
let (_, succ, store) = decrypt_str_or_original(&data, version);
assert!(!succ);
assert!(!store);
let data = data.into_bytes();
let (_, succ, store) = decrypt_vec_or_original(&data, version);
assert!(!succ);
assert!(!store);
println!("test speed");
let test_speed = |len: usize, name: &str| {
let mut data: Vec<u8> = vec![];
let mut rng = thread_rng();
for _ in 0..len {
data.push(rng.gen_range(0..255));
}
let start: Instant = Instant::now();
let encrypted = encrypt_vec_or_original(&data, version, len);
assert_ne!(data, decrypted);
let t1 = start.elapsed();
let start = Instant::now();
let (decrypted, _, _) = decrypt_vec_or_original(&encrypted, version);
let t2 = start.elapsed();
assert_eq!(data, decrypted);
println!("{name}");
println!("encrypt:{:?}, decrypt:{:?}", t1, t2);
let start: Instant = Instant::now();
let encrypted = base64::encode(&data, base64::Variant::Original);
let t1 = start.elapsed();
let start = Instant::now();
let decrypted = base64::decode(&encrypted, base64::Variant::Original).unwrap();
let t2 = start.elapsed();
assert_eq!(data, decrypted);
println!("base64, encrypt:{:?}, decrypt:{:?}", t1, t2,);
};
test_speed(128, "128");
test_speed(1024, "1k");
test_speed(1024 * 1024, "1M");
test_speed(10 * 1024 * 1024, "10M");
test_speed(100 * 1024 * 1024, "100M");
}
#[test]
fn test_is_encrypted() {
use super::*;
use sodiumoxide::base64::{encode, Variant};
use sodiumoxide::crypto::secretbox;
// Empty data should not be considered encrypted
assert!(!is_encrypted(b""));
assert!(!is_encrypted(b"0"));
assert!(!is_encrypted(b"00"));
// Data without "00" prefix should not be considered encrypted
assert!(!is_encrypted(b"01abcd"));
assert!(!is_encrypted(b"99abcd"));
assert!(!is_encrypted(b"hello world"));
// Data with "00" prefix but invalid base64 should not be considered encrypted
assert!(!is_encrypted(b"00!!!invalid base64!!!"));
assert!(!is_encrypted(b"00@#$%"));
// Data with "00" prefix and valid base64 but shorter than MACBYTES is not encrypted
assert!(!is_encrypted(b"00YWJjZA==")); // "abcd" in base64
assert!(!is_encrypted(b"00SGVsbG8gV29ybGQ=")); // "Hello World" in base64
// Data with "00" prefix and valid base64 with decoded len == MACBYTES is considered encrypted
let exact_mac = vec![0u8; secretbox::MACBYTES];
let exact_mac_b64 = encode(&exact_mac, Variant::Original);
let exact_mac_candidate = format!("00{exact_mac_b64}");
assert!(is_encrypted(exact_mac_candidate.as_bytes()));
// Real encrypted data should be detected
let version = "00";
let max_len = 128;
let encrypted_str = encrypt_str_or_original("1", version, max_len);
assert!(is_encrypted(encrypted_str.as_bytes()));
let encrypted_vec = encrypt_vec_or_original(b"1", version, max_len);
assert!(is_encrypted(&encrypted_vec));
// Original unencrypted data should not be detected as encrypted
assert!(!is_encrypted(b"1"));
assert!(!is_encrypted("1".as_bytes()));
}
#[test]
fn test_encrypted_payload_min_len_macbytes() {
use super::*;
use sodiumoxide::base64::{decode, Variant};
use sodiumoxide::crypto::secretbox;
let version = "00";
let max_len = 128;
let encrypted_str = encrypt_str_or_original("1", version, max_len);
let decoded = decode(&encrypted_str.as_bytes()[VERSION_LEN..], Variant::Original).unwrap();
assert!(
decoded.len() >= secretbox::MACBYTES,
"decoded encrypted payload must be at least MACBYTES"
);
let encrypted_vec = encrypt_vec_or_original(b"1", version, max_len);
let decoded = decode(&encrypted_vec[VERSION_LEN..], Variant::Original).unwrap();
assert!(
decoded.len() >= secretbox::MACBYTES,
"decoded encrypted payload must be at least MACBYTES"
);
}
// Test decryption fallback when data was encrypted with key_pair but decryption tries machine_uid first
#[test]
#[cfg(not(any(target_os = "android", target_os = "ios")))]
fn test_decrypt_with_pk_fallback() {
use sodiumoxide::crypto::secretbox;
use std::convert::TryInto;
let uuid = crate::get_uuid();
let pk = crate::config::Config::get_key_pair().1;
// Ensure uuid != pk, otherwise fallback branch won't be tested
if uuid == pk {
eprintln!("skip: uuid == pk, fallback branch won't be tested");
return;
}
let data = b"test password 123";
let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]);
// Encrypt with pk (simulating machine_uid failure during encryption)
let mut pk_keybuf = pk;
pk_keybuf.resize(secretbox::KEYBYTES, 0);
let pk_key = secretbox::Key(pk_keybuf.try_into().unwrap());
let encrypted = secretbox::seal(data, &nonce, &pk_key);
// Decrypt using symmetric_crypt (should fallback to pk since uuid differs)
let decrypted = super::symmetric_crypt(&encrypted, false);
assert!(
decrypted.is_ok(),
"Decryption with pk fallback should succeed"
);
assert_eq!(decrypted.unwrap(), data);
}
}
+572
View File
@@ -0,0 +1,572 @@
use crate::ResultType;
use std::{
collections::HashMap,
path::{Path, PathBuf},
process::Command,
};
use users::{get_current_uid, get_user_by_uid, os::unix::UserExt};
use sctk::{
output::OutputData,
output::{OutputHandler, OutputState},
reexports::client::protocol::wl_output::WlOutput,
reexports::client::{globals, Proxy},
reexports::client::{Connection, QueueHandle},
registry::{ProvidesRegistryState, RegistryState},
};
lazy_static::lazy_static! {
pub static ref DISTRO: Distro = Distro::new();
}
// to-do: There seems to be some runtime issue that causes the audit logs to be generated.
// We may need to fix this and remove this workaround in the future.
//
// We use the pre-search method to find the command path to avoid the audit logs on some systems.
// No idea why the audit logs happen.
// Though the audit logs may disappear after rebooting.
//
// See https://github.com/rustdesk/rustdesk/discussions/11959
//
// `ausearch -x /usr/share/rustdesk/rustdesk` will return
// ...
// time->Tue Jun 24 10:40:43 2025
// type=PROCTITLE msg=audit(1750776043.446:192757): proctitle=2F7573722F62696E2F727573746465736B002D2D73657276696365
// type=PATH msg=audit(1750776043.446:192757): item=0 name="/usr/local/bin/sh" nametype=UNKNOWN cap_fp=0 cap_fi=0 cap_fe=0 cap_fver=0 cap_frootid=0
// type=CWD msg=audit(1750776043.446:192757): cwd="/"
// type=SYSCALL msg=audit(1750776043.446:192757): arch=c000003e syscall=59 success=no exit=-2 a0=7fb7dbd22da0 a1=1d65f2c0 a2=7ffc25193360 a3=7ffc25194ec0 items=1 ppid=172208 pid=267565 auid=4294967295 uid=0 gid=0 euid=0 suid=0 fsuid=0 egid=0 sgid=0 fsgid=0 tty=(none) ses=4294967295 comm="rustdesk" exe="/usr/share/rustdesk/rustdesk" subj=unconfined key="processos_criados"
// ----
// time->Tue Jun 24 10:40:43 2025
// type=PROCTITLE msg=audit(1750776043.446:192758): proctitle=2F7573722F62696E2F727573746465736B002D2D73657276696365
// type=PATH msg=audit(1750776043.446:192758): item=0 name="/usr/sbin/sh" nametype=UNKNOWN cap_fp=0 cap_fi=0 cap_fe=0 cap_fver=0 cap_frootid=0
// ...
lazy_static::lazy_static! {
pub static ref CMD_LOGINCTL: String = find_cmd_path("loginctl");
pub static ref CMD_PS: String = find_cmd_path("ps");
pub static ref CMD_SH: String = find_cmd_path("sh");
}
pub const DISPLAY_SERVER_WAYLAND: &str = "wayland";
pub const DISPLAY_SERVER_X11: &str = "x11";
pub const DISPLAY_DESKTOP_KDE: &str = "KDE";
pub const XDG_CURRENT_DESKTOP: &str = "XDG_CURRENT_DESKTOP";
pub struct Distro {
pub name: String,
pub version_id: String,
}
impl Distro {
fn new() -> Self {
let name = run_cmds("awk -F'=' '/^NAME=/ {print $2}' /etc/os-release")
.unwrap_or_default()
.trim()
.trim_matches('"')
.to_string();
let version_id = run_cmds("awk -F'=' '/^VERSION_ID=/ {print $2}' /etc/os-release")
.unwrap_or_default()
.trim()
.trim_matches('"')
.to_string();
Self { name, version_id }
}
}
fn find_cmd_path(cmd: &'static str) -> String {
let test_cmd = format!("/bin/{}", cmd);
if std::path::Path::new(&test_cmd).exists() {
return test_cmd;
}
let test_cmd = format!("/usr/bin/{}", cmd);
if std::path::Path::new(&test_cmd).exists() {
return test_cmd;
}
if let Ok(output) = Command::new("which").arg(cmd).output() {
if output.status.success() {
return String::from_utf8_lossy(&output.stdout).trim().to_string();
}
}
cmd.to_string()
}
// Deprecated. Use `hbb_common::platform::linux::is_kde_session()` instead for now.
// Or we need to set the correct environment variable in the server process.
#[inline]
pub fn is_kde() -> bool {
if let Ok(env) = std::env::var(XDG_CURRENT_DESKTOP) {
env == DISPLAY_DESKTOP_KDE
} else {
false
}
}
// Don't use `hbb_common::platform::linux::is_kde()` here.
// It's not correct in the server process.
pub fn is_kde_session() -> bool {
std::process::Command::new(CMD_SH.as_str())
.arg("-c")
.arg("pgrep -f kded[0-9]+")
.stdout(std::process::Stdio::piped())
.output()
.map(|o| !o.stdout.is_empty())
.unwrap_or(false)
}
#[inline]
pub fn is_gdm_user(username: &str) -> bool {
username == "gdm" || username == "sddm"
// || username == "lightgdm"
}
#[inline]
pub fn is_desktop_wayland() -> bool {
get_display_server() == DISPLAY_SERVER_WAYLAND
}
#[inline]
pub fn is_x11_or_headless() -> bool {
!is_desktop_wayland()
}
// -1
const INVALID_SESSION: &str = "4294967295";
pub fn get_display_server() -> String {
// Check for forced display server environment variable first
if let Ok(forced_display) = std::env::var("RUSTDESK_FORCED_DISPLAY_SERVER") {
return forced_display;
}
// Check if `loginctl` can be called successfully
if run_loginctl(None).is_err() {
return DISPLAY_SERVER_X11.to_owned();
}
let mut session = get_values_of_seat0(&[0])[0].clone();
if session.is_empty() {
// loginctl has not given the expected output. try something else.
if let Ok(sid) = std::env::var("XDG_SESSION_ID") {
// could also execute "cat /proc/self/sessionid"
session = sid;
}
if session.is_empty() {
session = run_cmds("cat /proc/self/sessionid").unwrap_or_default();
if session == INVALID_SESSION {
session = "".to_owned();
}
}
}
if session.is_empty() {
std::env::var("XDG_SESSION_TYPE").unwrap_or("x11".to_owned())
} else {
get_display_server_of_session(&session)
}
}
pub fn get_display_server_of_session(session: &str) -> String {
let mut display_server = if let Ok(output) =
run_loginctl(Some(vec!["show-session", "-p", "Type", session]))
// Check session type of the session
{
String::from_utf8_lossy(&output.stdout)
.replace("Type=", "")
.trim_end()
.into()
} else {
"".to_owned()
};
if display_server.is_empty() || display_server == "tty" || display_server == "unspecified" {
if let Ok(sestype) = std::env::var("XDG_SESSION_TYPE") {
if !sestype.is_empty() {
return sestype.to_lowercase();
}
}
display_server = "x11".to_owned();
}
display_server.to_lowercase()
}
#[inline]
fn line_values(indices: &[usize], line: &str) -> Vec<String> {
indices
.into_iter()
.map(|idx| line.split_whitespace().nth(*idx).unwrap_or("").to_owned())
.collect::<Vec<String>>()
}
#[inline]
pub fn get_values_of_seat0(indices: &[usize]) -> Vec<String> {
_get_values_of_seat0(indices, true)
}
#[inline]
pub fn get_values_of_seat0_with_gdm_wayland(indices: &[usize]) -> Vec<String> {
_get_values_of_seat0(indices, false)
}
// Ignore "3 sessions listed."
fn ignore_loginctl_line(line: &str) -> bool {
line.contains("sessions") || line.split(" ").count() < 4
}
fn _get_values_of_seat0(indices: &[usize], ignore_gdm_wayland: bool) -> Vec<String> {
if let Ok(output) = run_loginctl(None) {
for line in String::from_utf8_lossy(&output.stdout).lines() {
if ignore_loginctl_line(line) {
continue;
}
if line.contains("seat0") {
if let Some(sid) = line.split_whitespace().next() {
if is_active(sid) {
if ignore_gdm_wayland {
if is_gdm_user(line.split_whitespace().nth(2).unwrap_or(""))
&& get_display_server_of_session(sid) == DISPLAY_SERVER_WAYLAND
{
continue;
}
}
return line_values(indices, line);
}
}
}
}
// some case, there is no seat0 https://github.com/rustdesk/rustdesk/issues/73
for line in String::from_utf8_lossy(&output.stdout).lines() {
if ignore_loginctl_line(line) {
continue;
}
if let Some(sid) = line.split_whitespace().next() {
if is_active(sid) {
let d = get_display_server_of_session(sid);
if ignore_gdm_wayland {
if is_gdm_user(line.split_whitespace().nth(2).unwrap_or(""))
&& d == DISPLAY_SERVER_WAYLAND
{
continue;
}
}
if d == "tty" || d == "unspecified" {
continue;
}
return line_values(indices, line);
}
}
}
}
line_values(indices, "")
}
pub fn is_active(sid: &str) -> bool {
if let Ok(output) = run_loginctl(Some(vec!["show-session", "-p", "State", sid])) {
String::from_utf8_lossy(&output.stdout).contains("active")
} else {
false
}
}
pub fn is_active_and_seat0(sid: &str) -> bool {
if let Ok(output) = run_loginctl(Some(vec!["show-session", sid])) {
String::from_utf8_lossy(&output.stdout).contains("State=active")
&& String::from_utf8_lossy(&output.stdout).contains("Seat=seat0")
} else {
false
}
}
// Check both "Lock" and "Switch user"
pub fn is_session_locked(sid: &str) -> bool {
if let Ok(output) = run_loginctl(Some(vec!["show-session", sid, "--property=LockedHint"])) {
String::from_utf8_lossy(&output.stdout).contains("LockedHint=yes")
} else {
false
}
}
// **Note** that the return value here, the last character is '\n'.
// Use `run_cmds_trim_newline()` if you want to remove '\n' at the end.
pub fn run_cmds(cmds: &str) -> ResultType<String> {
let output = std::process::Command::new(CMD_SH.as_str())
.args(vec!["-c", cmds])
.output()?;
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
pub fn run_cmds_trim_newline(cmds: &str) -> ResultType<String> {
let output = std::process::Command::new(CMD_SH.as_str())
.args(vec!["-c", cmds])
.output()?;
let out = String::from_utf8_lossy(&output.stdout);
Ok(if out.ends_with('\n') {
out[..out.len() - 1].to_string()
} else {
out.to_string()
})
}
fn run_loginctl(args: Option<Vec<&str>>) -> std::io::Result<std::process::Output> {
if std::env::var("FLATPAK_ID").is_ok() {
let mut l_args = CMD_LOGINCTL.to_string();
if let Some(a) = args.as_ref() {
l_args = format!("{} {}", l_args, a.join(" "));
}
let res = std::process::Command::new("flatpak-spawn")
.args(vec![String::from("--host"), l_args])
.output();
if res.is_ok() {
return res;
}
}
let mut cmd = std::process::Command::new(CMD_LOGINCTL.as_str());
if let Some(a) = args {
return cmd.args(a).output();
}
cmd.output()
}
/// forever: may not work
#[cfg(target_os = "linux")]
pub fn system_message(title: &str, msg: &str, forever: bool) -> ResultType<()> {
let cmds: HashMap<&str, Vec<&str>> = HashMap::from([
("notify-send", [title, msg].to_vec()),
(
"zenity",
[
"--info",
"--timeout",
if forever { "0" } else { "3" },
"--title",
title,
"--text",
msg,
]
.to_vec(),
),
("kdialog", ["--title", title, "--msgbox", msg].to_vec()),
(
"xmessage",
[
"-center",
"-timeout",
if forever { "0" } else { "3" },
title,
msg,
]
.to_vec(),
),
]);
for (k, v) in cmds {
if Command::new(k).args(v).spawn().is_ok() {
return Ok(());
}
}
crate::bail!("failed to post system message");
}
#[derive(Debug, Clone)]
pub struct WaylandDisplayInfo {
pub name: String,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub logical_size: Option<(i32, i32)>,
pub refresh_rate: i32,
}
// Retrieves information about all connected displays via the Wayland protocol.
pub fn get_wayland_displays() -> ResultType<Vec<WaylandDisplayInfo>> {
struct WaylandEnv {
registry_state: RegistryState,
output_state: OutputState,
}
impl OutputHandler for WaylandEnv {
fn output_state(&mut self) -> &mut OutputState {
&mut self.output_state
}
fn new_output(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
fn update_output(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
fn output_destroyed(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
}
impl ProvidesRegistryState for WaylandEnv {
fn registry(&mut self) -> &mut RegistryState {
&mut self.registry_state
}
sctk::registry_handlers!();
}
sctk::delegate_output!(WaylandEnv);
sctk::delegate_registry!(WaylandEnv);
let conn = Connection::connect_to_env()?;
let (globals, mut event_queue) = globals::registry_queue_init(&conn)?;
let queue_handle = event_queue.handle();
let registry_state = RegistryState::new(&globals);
let output_state = OutputState::new(&globals, &queue_handle);
let mut environment = WaylandEnv {
registry_state,
output_state,
};
event_queue.roundtrip(&mut environment)?;
let outputs: Vec<_> = environment.output_state.outputs().collect();
let mut display_infos = Vec::new();
for output in outputs {
if let Some(output_data) = output.data::<OutputData>() {
output_data.with_output_info(|info| {
if let Some(mode) = info.modes.iter().find(|m| m.current) {
let (x, y) = info.location;
let (width, height) = mode.dimensions;
let refresh_rate = mode.refresh_rate;
let name = info.name.clone().unwrap_or_default();
let logical_size = info.logical_size;
display_infos.push(WaylandDisplayInfo {
name,
x,
y,
width,
height,
logical_size,
refresh_rate,
});
}
});
}
}
Ok(display_infos)
}
/// Escape a string for safe use in shell commands by wrapping in single quotes.
///
/// This function handles the edge case of single quotes within the string by:
/// 1. Ending the current single-quoted section
/// 2. Adding an escaped single quote
/// 3. Starting a new single-quoted section
///
/// Example: "it's here" -> "'it'\''s here'"
#[inline]
pub fn shell_quote(s: &str) -> String {
format!("'{}'", s.replace("'", "'\\''"))
}
/// Get the current user's home directory via getpwuid (trusted source).
///
/// This function uses the system's password database (via `getpwuid`) to retrieve
/// the home directory, avoiding the security risk of relying on the `HOME`
/// environment variable which can be manipulated by untrusted input.
///
/// # Returns
/// - `Some(PathBuf)` if the home directory was found and exists
/// - `None` if the user lookup failed or the directory doesn't exist
///
/// # Security
/// This function is designed to be safe against confused-deputy attacks where
/// an attacker might manipulate environment variables to influence privileged
/// operations.
pub fn get_home_dir_trusted() -> Option<PathBuf> {
let uid = get_current_uid();
match get_user_by_uid(uid) {
Some(user) => {
let home = user.home_dir();
if Path::is_dir(home) {
Some(PathBuf::from(home))
} else {
log::warn!(
"Home directory for uid {} does not exist or is not a directory: {:?}",
uid,
home
);
None
}
}
None => {
log::warn!("Failed to get user info for uid {}", uid);
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_run_cmds_trim_newline() {
assert_eq!(run_cmds_trim_newline("echo -n 123").unwrap(), "123");
assert_eq!(run_cmds_trim_newline("echo 123").unwrap(), "123");
assert_eq!(
run_cmds_trim_newline("whoami").unwrap() + "\n",
run_cmds("whoami").unwrap()
);
}
/// Test get_home_dir_trusted: returns valid path and ignores HOME env var
#[test]
fn test_get_home_dir_trusted() {
let original_home = std::env::var("HOME").ok();
// Set HOME to a fake/malicious path
std::env::set_var("HOME", "/tmp/fake_malicious_home");
let result = get_home_dir_trusted();
// Restore original HOME
match original_home {
Some(home) => std::env::set_var("HOME", home),
None => std::env::remove_var("HOME"),
}
// Verify: returns valid path that is NOT the fake HOME
if let Some(path) = result {
assert!(path.is_absolute(), "Path should be absolute: {:?}", path);
assert!(path.is_dir(), "Path should be a directory: {:?}", path);
assert_ne!(
path.to_string_lossy(),
"/tmp/fake_malicious_home",
"Should not use HOME env var"
);
}
}
/// Test shell_quote with normal strings
#[test]
fn test_shell_quote_normal() {
assert_eq!(shell_quote("hello"), "'hello'");
assert_eq!(shell_quote("/home/user"), "'/home/user'");
}
/// Test shell_quote with spaces
#[test]
fn test_shell_quote_spaces() {
assert_eq!(shell_quote("/home/my user/file"), "'/home/my user/file'");
assert_eq!(shell_quote("path with spaces"), "'path with spaces'");
}
/// Test shell_quote with single quotes (the tricky case)
#[test]
fn test_shell_quote_single_quotes() {
assert_eq!(shell_quote("it's"), "'it'\\''s'");
assert_eq!(shell_quote("don't stop"), "'don'\\''t stop'");
}
/// Test shell_quote with shell metacharacters
#[test]
fn test_shell_quote_metacharacters() {
// These should all be safely quoted
assert_eq!(shell_quote("test;rm -rf /"), "'test;rm -rf /'");
assert_eq!(shell_quote("$(whoami)"), "'$(whoami)'");
assert_eq!(shell_quote("`id`"), "'`id`'");
assert_eq!(shell_quote("a && b"), "'a && b'");
assert_eq!(shell_quote("a | b"), "'a | b'");
}
}
+55
View File
@@ -0,0 +1,55 @@
use crate::ResultType;
use osascript;
use serde_derive::{Deserialize, Serialize};
#[derive(Serialize)]
struct AlertParams {
title: String,
message: String,
alert_type: String,
buttons: Vec<String>,
}
#[derive(Deserialize)]
struct AlertResult {
#[serde(rename = "buttonReturned")]
button: String,
}
/// Firstly run the specified app, then alert a dialog. Return the clicked button value.
///
/// # Arguments
///
/// * `app` - The app to execute the script.
/// * `alert_type` - Alert type. . informational, warning, critical
/// * `title` - The alert title.
/// * `message` - The alert message.
/// * `buttons` - The buttons to show.
pub fn alert(
app: String,
alert_type: String,
title: String,
message: String,
buttons: Vec<String>,
) -> ResultType<String> {
let script = osascript::JavaScript::new(&format!(
"
var App = Application('{}');
App.includeStandardAdditions = true;
return App.displayAlert($params.title, {{
message: $params.message,
'as': $params.alert_type,
buttons: $params.buttons,
}});
",
app
));
let result: AlertResult = script.execute_with_params(AlertParams {
title,
message,
alert_type,
buttons,
})?;
Ok(result.button)
}
+82
View File
@@ -0,0 +1,82 @@
#[cfg(target_os = "linux")]
pub mod linux;
#[cfg(target_os = "macos")]
pub mod macos;
#[cfg(target_os = "windows")]
pub mod windows;
#[cfg(not(debug_assertions))]
use crate::{config::Config, log};
#[cfg(not(debug_assertions))]
use std::process::exit;
#[cfg(not(debug_assertions))]
static mut GLOBAL_CALLBACK: Option<Box<dyn Fn()>> = None;
#[cfg(not(debug_assertions))]
extern "C" fn breakdown_signal_handler(sig: i32) {
let mut stack = vec![];
backtrace::trace(|frame| {
backtrace::resolve_frame(frame, |symbol| {
if let Some(name) = symbol.name() {
stack.push(name.to_string());
}
});
true // keep going to the next frame
});
let mut info = String::default();
if stack.iter().any(|s| {
s.contains(&"nouveau_pushbuf_kick")
|| s.to_lowercase().contains("nvidia")
|| s.contains("gdk_window_end_draw_frame")
|| s.contains("glGetString")
}) {
Config::set_option("allow-always-software-render".to_string(), "Y".to_string());
info = "Always use software rendering will be set.".to_string();
log::info!("{}", info);
}
if stack.iter().any(|s| {
s.to_lowercase().contains("nvidia")
|| s.to_lowercase().contains("amf")
|| s.to_lowercase().contains("mfx")
|| s.contains("cuProfilerStop")
}) {
Config::set_option("enable-hwcodec".to_string(), "N".to_string());
info = "Perhaps hwcodec causing the crash, disable it first".to_string();
log::info!("{}", info);
}
log::error!(
"Got signal {} and exit. stack:\n{}",
sig,
stack.join("\n").to_string()
);
if !info.is_empty() {
#[cfg(target_os = "linux")]
linux::system_message(
"RustDesk",
&format!("Got signal {} and exit.{}", sig, info),
true,
)
.ok();
}
unsafe {
#[allow(static_mut_refs)]
if let Some(callback) = &GLOBAL_CALLBACK {
callback()
}
}
exit(0);
}
#[cfg(not(debug_assertions))]
pub fn register_breakdown_handler<T>(callback: T)
where
T: Fn() + 'static,
{
unsafe {
GLOBAL_CALLBACK = Some(Box::new(callback));
libc::signal(libc::SIGSEGV, breakdown_signal_handler as _);
}
}
+198
View File
@@ -0,0 +1,198 @@
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
time::Instant,
};
use winapi::{
shared::minwindef::{DWORD, FALSE, TRUE},
um::{
handleapi::CloseHandle,
pdh::{
PdhAddEnglishCounterA, PdhCloseQuery, PdhCollectQueryData, PdhCollectQueryDataEx,
PdhGetFormattedCounterValue, PdhOpenQueryA, PDH_FMT_COUNTERVALUE, PDH_FMT_DOUBLE,
PDH_HCOUNTER, PDH_HQUERY,
},
synchapi::{CreateEventA, WaitForSingleObject},
sysinfoapi::VerSetConditionMask,
winbase::{VerifyVersionInfoW, INFINITE, WAIT_OBJECT_0},
winnt::{
HANDLE, OSVERSIONINFOEXW, VER_BUILDNUMBER, VER_GREATER_EQUAL, VER_MAJORVERSION,
VER_MINORVERSION, VER_SERVICEPACKMAJOR, VER_SERVICEPACKMINOR,
},
},
};
lazy_static::lazy_static! {
static ref CPU_USAGE_ONE_MINUTE: Arc<Mutex<Option<(f64, Instant)>>> = Arc::new(Mutex::new(None));
}
// https://github.com/mgostIH/process_list/blob/master/src/windows/mod.rs
#[repr(transparent)]
pub struct RAIIHandle(pub HANDLE);
impl Drop for RAIIHandle {
fn drop(&mut self) {
// This never gives problem except when running under a debugger.
unsafe { CloseHandle(self.0) };
}
}
#[repr(transparent)]
pub(self) struct RAIIPDHQuery(pub PDH_HQUERY);
impl Drop for RAIIPDHQuery {
fn drop(&mut self) {
unsafe { PdhCloseQuery(self.0) };
}
}
pub fn start_cpu_performance_monitor() {
// Code from:
// https://learn.microsoft.com/en-us/windows/win32/perfctrs/collecting-performance-data
// https://learn.microsoft.com/en-us/windows/win32/api/pdh/nf-pdh-pdhcollectquerydataex
// Why value lower than taskManager:
// https://aaron-margosis.medium.com/task-managers-cpu-numbers-are-all-but-meaningless-2d165b421e43
// Therefore we should compare with Precess Explorer rather than taskManager
let f = || unsafe {
// load avg or cpu usage, test with prime95.
// Prefer cpu usage because we can get accurate value from Precess Explorer.
// const COUNTER_PATH: &'static str = "\\System\\Processor Queue Length\0";
const COUNTER_PATH: &'static str = "\\Processor(_total)\\% Processor Time\0";
const SAMPLE_INTERVAL: DWORD = 2; // 2 second
let mut ret;
let mut query: PDH_HQUERY = std::mem::zeroed();
ret = PdhOpenQueryA(std::ptr::null() as _, 0, &mut query);
if ret != 0 {
log::error!("PdhOpenQueryA failed: 0x{:X}", ret);
return;
}
let _query = RAIIPDHQuery(query);
let mut counter: PDH_HCOUNTER = std::mem::zeroed();
ret = PdhAddEnglishCounterA(query, COUNTER_PATH.as_ptr() as _, 0, &mut counter);
if ret != 0 {
log::error!("PdhAddEnglishCounterA failed: 0x{:X}", ret);
return;
}
ret = PdhCollectQueryData(query);
if ret != 0 {
log::error!("PdhCollectQueryData failed: 0x{:X}", ret);
return;
}
let mut _counter_type: DWORD = 0;
let mut counter_value: PDH_FMT_COUNTERVALUE = std::mem::zeroed();
let event = CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null() as _);
if event.is_null() {
log::error!("CreateEventA failed");
return;
}
let _event: RAIIHandle = RAIIHandle(event);
ret = PdhCollectQueryDataEx(query, SAMPLE_INTERVAL, event);
if ret != 0 {
log::error!("PdhCollectQueryDataEx failed: 0x{:X}", ret);
return;
}
let mut queue: VecDeque<f64> = VecDeque::new();
let mut recent_valid: VecDeque<bool> = VecDeque::new();
loop {
// latest one minute
if queue.len() == 31 {
queue.pop_front();
}
if recent_valid.len() == 31 {
recent_valid.pop_front();
}
// allow get value within one minute
if queue.len() > 0 && recent_valid.iter().filter(|v| **v).count() > queue.len() / 2 {
let sum: f64 = queue.iter().map(|f| f.to_owned()).sum();
let avg = sum / (queue.len() as f64);
*CPU_USAGE_ONE_MINUTE.lock().unwrap() = Some((avg, Instant::now()));
} else {
*CPU_USAGE_ONE_MINUTE.lock().unwrap() = None;
}
if WAIT_OBJECT_0 != WaitForSingleObject(event, INFINITE) {
recent_valid.push_back(false);
continue;
}
if PdhGetFormattedCounterValue(
counter,
PDH_FMT_DOUBLE,
&mut _counter_type,
&mut counter_value,
) != 0
|| counter_value.CStatus != 0
{
recent_valid.push_back(false);
continue;
}
queue.push_back(counter_value.u.doubleValue().clone());
recent_valid.push_back(true);
}
};
use std::sync::Once;
static ONCE: Once = Once::new();
ONCE.call_once(|| {
std::thread::spawn(f);
});
}
pub fn cpu_uage_one_minute() -> Option<f64> {
let v = CPU_USAGE_ONE_MINUTE.lock().unwrap().clone();
if let Some((v, instant)) = v {
if instant.elapsed().as_secs() < 30 {
return Some(v);
}
}
None
}
pub fn sync_cpu_usage(cpu_usage: Option<f64>) {
let v = match cpu_usage {
Some(cpu_usage) => Some((cpu_usage, Instant::now())),
None => None,
};
*CPU_USAGE_ONE_MINUTE.lock().unwrap() = v;
log::info!("cpu usage synced: {:?}", cpu_usage);
}
// https://learn.microsoft.com/en-us/windows/win32/sysinfo/targeting-your-application-at-windows-8-1
// https://github.com/nodejs/node-convergence-archive/blob/e11fe0c2777561827cdb7207d46b0917ef3c42a7/deps/uv/src/win/util.c#L780
pub fn is_windows_version_or_greater(
os_major: u32,
os_minor: u32,
build_number: u32,
service_pack_major: u32,
service_pack_minor: u32,
) -> bool {
let mut osvi: OSVERSIONINFOEXW = unsafe { std::mem::zeroed() };
osvi.dwOSVersionInfoSize = std::mem::size_of::<OSVERSIONINFOEXW>() as DWORD;
osvi.dwMajorVersion = os_major as _;
osvi.dwMinorVersion = os_minor as _;
osvi.dwBuildNumber = build_number as _;
osvi.wServicePackMajor = service_pack_major as _;
osvi.wServicePackMinor = service_pack_minor as _;
let result = unsafe {
let mut condition_mask = 0;
let op = VER_GREATER_EQUAL;
condition_mask = VerSetConditionMask(condition_mask, VER_MAJORVERSION, op);
condition_mask = VerSetConditionMask(condition_mask, VER_MINORVERSION, op);
condition_mask = VerSetConditionMask(condition_mask, VER_BUILDNUMBER, op);
condition_mask = VerSetConditionMask(condition_mask, VER_SERVICEPACKMAJOR, op);
condition_mask = VerSetConditionMask(condition_mask, VER_SERVICEPACKMINOR, op);
VerifyVersionInfoW(
&mut osvi as *mut OSVERSIONINFOEXW,
VER_MAJORVERSION
| VER_MINORVERSION
| VER_BUILDNUMBER
| VER_SERVICEPACKMAJOR
| VER_SERVICEPACKMINOR,
condition_mask,
)
};
result == TRUE
}
+1
View File
@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));
+716
View File
@@ -0,0 +1,716 @@
use std::{
io::Error as IoError,
net::{SocketAddr, ToSocketAddrs},
};
use anyhow::bail;
use async_recursion::async_recursion;
use base64::{engine::general_purpose, Engine};
use httparse::{Error as HttpParseError, Response, EMPTY_HEADER};
use thiserror::Error as ThisError;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr};
use tokio_util::codec::Framed;
use url::Url;
use crate::{
bytes_codec::BytesCodec,
config::Socks5Server,
tcp::{DynTcpStream, FramedStream},
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
#[derive(Debug, ThisError)]
pub enum ProxyError {
#[error("IO Error: {0}")]
IoError(#[from] IoError),
#[error("Target parse error: {0}")]
TargetParseError(String),
#[error("HTTP parse error: {0}")]
HttpParseError(#[from] HttpParseError),
#[error("The maximum response header length is exceeded: {0}")]
MaximumResponseHeaderLengthExceeded(usize),
#[error("The end of file is reached")]
EndOfFile,
#[error("The url is error: {0}")]
UrlBadScheme(String),
#[error("The url parse error: {0}")]
UrlParseScheme(#[from] url::ParseError),
#[error("No HTTP code was found in the response")]
NoHttpCode,
#[error("The HTTP code is not equal 200: {0}")]
HttpCode200(u16),
#[error("The proxy address resolution failed: {0}")]
AddressResolutionFailed(String),
#[error("The native tls error: {0}")]
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
}
const MAXIMUM_RESPONSE_HEADER_LENGTH: usize = 4096;
/// The maximum HTTP Headers, which can be parsed.
const MAXIMUM_RESPONSE_HEADERS: usize = 16;
const DEFINE_TIME_OUT: u64 = 600;
pub trait IntoUrl {
// Besides parsing as a valid `Url`, the `Url` must be a valid
// `http::Uri`, in that it makes sense to use in a network request.
fn into_url(self) -> Result<Url, ProxyError>;
fn as_str(&self) -> &str;
}
impl IntoUrl for Url {
fn into_url(self) -> Result<Url, ProxyError> {
if self.has_host() {
Ok(self)
} else {
Err(ProxyError::UrlBadScheme(self.to_string()))
}
}
fn as_str(&self) -> &str {
self.as_ref()
}
}
impl<'a> IntoUrl for &'a str {
fn into_url(self) -> Result<Url, ProxyError> {
Url::parse(self)
.map_err(ProxyError::UrlParseScheme)?
.into_url()
}
fn as_str(&self) -> &str {
self
}
}
impl<'a> IntoUrl for &'a String {
fn into_url(self) -> Result<Url, ProxyError> {
(&**self).into_url()
}
fn as_str(&self) -> &str {
self.as_ref()
}
}
impl<'a> IntoUrl for String {
fn into_url(self) -> Result<Url, ProxyError> {
(&*self).into_url()
}
fn as_str(&self) -> &str {
self.as_ref()
}
}
#[derive(Clone)]
pub struct Auth {
user_name: String,
password: String,
}
impl Auth {
fn get_proxy_authorization(&self) -> String {
format!(
"Proxy-Authorization: Basic {}\r\n",
self.get_basic_authorization()
)
}
pub fn get_basic_authorization(&self) -> String {
let authorization = format!("{}:{}", &self.user_name, &self.password);
general_purpose::STANDARD.encode(authorization.as_bytes())
}
pub fn username(&self) -> &str {
&self.user_name
}
pub fn password(&self) -> &str {
&self.password
}
}
#[derive(Clone)]
pub enum ProxyScheme {
Http {
auth: Option<Auth>,
host: String,
},
Https {
auth: Option<Auth>,
host: String,
},
Socks5 {
addr: SocketAddr,
auth: Option<Auth>,
remote_dns: bool,
},
}
impl ProxyScheme {
pub fn maybe_auth(&self) -> Option<&Auth> {
match self {
ProxyScheme::Http { auth, .. }
| ProxyScheme::Https { auth, .. }
| ProxyScheme::Socks5 { auth, .. } => auth.as_ref(),
}
}
fn socks5(addr: SocketAddr) -> Result<Self, ProxyError> {
Ok(ProxyScheme::Socks5 {
addr,
auth: None,
remote_dns: false,
})
}
fn http(host: &str) -> Result<Self, ProxyError> {
Ok(ProxyScheme::Http {
auth: None,
host: host.to_string(),
})
}
fn https(host: &str) -> Result<Self, ProxyError> {
Ok(ProxyScheme::Https {
auth: None,
host: host.to_string(),
})
}
fn set_basic_auth<T: Into<String>, U: Into<String>>(&mut self, username: T, password: U) {
let auth = Auth {
user_name: username.into(),
password: password.into(),
};
match self {
ProxyScheme::Http { auth: a, .. } => *a = Some(auth),
ProxyScheme::Https { auth: a, .. } => *a = Some(auth),
ProxyScheme::Socks5 { auth: a, .. } => *a = Some(auth),
}
}
fn parse(url: Url) -> Result<Self, ProxyError> {
use url::Position;
// Resolve URL to a host and port
let to_addr = || {
let addrs = url.socket_addrs(|| match url.scheme() {
"socks5" => Some(1080),
_ => None,
})?;
addrs
.into_iter()
.next()
.ok_or_else(|| ProxyError::UrlParseScheme(url::ParseError::EmptyHost))
};
let mut scheme: Self = match url.scheme() {
"http" => Self::http(&url[Position::BeforeHost..Position::AfterPort])?,
"https" => Self::https(&url[Position::BeforeHost..Position::AfterPort])?,
"socks5" => Self::socks5(to_addr()?)?,
e => return Err(ProxyError::UrlBadScheme(e.to_string())),
};
if let Some(pwd) = url.password() {
let username = url.username();
scheme.set_basic_auth(username, pwd);
}
Ok(scheme)
}
pub async fn socket_addrs(&self) -> Result<SocketAddr, ProxyError> {
log::trace!("Resolving socket address");
match self {
ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await,
ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await,
ProxyScheme::Socks5 { addr, .. } => Ok(addr.clone()),
}
}
async fn resolve_host(&self, host: &str, default_port: u16) -> Result<SocketAddr, ProxyError> {
let (host_str, port) = match host.split_once(':') {
Some((h, p)) => (h, p.parse::<u16>().ok()),
None => (host, None),
};
let addr = (host_str, port.unwrap_or(default_port))
.to_socket_addrs()?
.next()
.ok_or_else(|| ProxyError::AddressResolutionFailed(host.to_string()))?;
Ok(addr)
}
pub fn get_domain(&self) -> Result<String, ProxyError> {
match self {
ProxyScheme::Http { host, .. } | ProxyScheme::Https { host, .. } => {
let domain = host
.split(':')
.next()
.ok_or_else(|| ProxyError::AddressResolutionFailed(host.clone()))?;
Ok(domain.to_string())
}
ProxyScheme::Socks5 { addr, .. } => match addr {
SocketAddr::V4(addr_v4) => Ok(addr_v4.ip().to_string()),
SocketAddr::V6(addr_v6) => Ok(addr_v6.ip().to_string()),
},
}
}
pub fn get_host_and_port(&self) -> Result<String, ProxyError> {
match self {
ProxyScheme::Http { host, .. } => Ok(self.append_default_port(host, 80)),
ProxyScheme::Https { host, .. } => Ok(self.append_default_port(host, 443)),
ProxyScheme::Socks5 { addr, .. } => Ok(format!("{}", addr)),
}
}
fn append_default_port(&self, host: &str, default_port: u16) -> String {
if host.contains(':') {
host.to_string()
} else {
format!("{}:{}", host, default_port)
}
}
}
pub trait IntoProxyScheme {
fn into_proxy_scheme(self) -> Result<ProxyScheme, ProxyError>;
}
impl<S: IntoUrl> IntoProxyScheme for S {
fn into_proxy_scheme(self) -> Result<ProxyScheme, ProxyError> {
// validate the URL
let url = match self.as_str().into_url() {
Ok(ok) => ok,
Err(e) => {
match e {
// If the string does not contain protocol headers, try to parse it using the socks5 protocol
ProxyError::UrlParseScheme(_source) => {
let try_this = format!("socks5://{}", self.as_str());
try_this.into_url()?
}
_ => {
return Err(e);
}
}
}
};
ProxyScheme::parse(url)
}
}
impl IntoProxyScheme for ProxyScheme {
fn into_proxy_scheme(self) -> Result<ProxyScheme, ProxyError> {
Ok(self)
}
}
#[derive(Clone)]
pub struct Proxy {
pub intercept: ProxyScheme,
ms_timeout: u64,
}
impl Proxy {
pub fn new<U: IntoProxyScheme>(proxy_scheme: U, ms_timeout: u64) -> Result<Self, ProxyError> {
Ok(Self {
intercept: proxy_scheme.into_proxy_scheme()?,
ms_timeout,
})
}
pub fn is_http_or_https(&self) -> bool {
return match self.intercept {
ProxyScheme::Socks5 { .. } => false,
_ => true,
};
}
pub fn from_conf(conf: &Socks5Server, ms_timeout: Option<u64>) -> Result<Self, ProxyError> {
let mut proxy;
match ms_timeout {
None => {
proxy = Self::new(&conf.proxy, DEFINE_TIME_OUT)?;
}
Some(time_out) => {
proxy = Self::new(&conf.proxy, time_out)?;
}
}
if !conf.password.is_empty() && !conf.username.is_empty() {
proxy = proxy.basic_auth(&conf.username, &conf.password);
}
Ok(proxy)
}
pub async fn proxy_addrs(&self) -> Result<SocketAddr, ProxyError> {
self.intercept.socket_addrs().await
}
fn basic_auth(mut self, username: &str, password: &str) -> Proxy {
self.intercept.set_basic_auth(username, password);
self
}
async fn new_stream(
&self,
local: SocketAddr,
proxy: SocketAddr,
) -> ResultType<tokio::net::TcpStream> {
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
Ok(stream)
}
pub async fn connect<'t, T>(
&self,
target: T,
local_addr: Option<SocketAddr>,
) -> ResultType<FramedStream>
where
T: IntoTargetAddr<'t>,
{
log::trace!("Connect to proxy server");
let proxy = self.proxy_addrs().await?;
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
let local = if let Some(addr) = local_addr {
addr
} else {
crate::config::Config::get_any_listen_addr(proxy.is_ipv4())
};
let stream = self.new_stream(local, proxy).await?;
let addr = stream.local_addr()?;
return match self.intercept {
ProxyScheme::Http { .. } => {
log::trace!("Connect to remote http proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr))
.await??;
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
None,
0,
))
}
ProxyScheme::Https { .. } => {
log::trace!("Connect to remote https proxy server: {}", proxy);
let url = format!("https://{}", self.intercept.get_host_and_port()?);
let tls_type = get_cached_tls_type(&url);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
let stream = match tls_type.unwrap_or(TlsType::Rustls) {
TlsType::Rustls => {
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
Some(stream),
&target_addr,
tls_type.is_some(),
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await?
}
TlsType::NativeTls => {
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
&target_addr,
danger_accept_invalid_cert,
)
.await?
}
_ => {
// Unreachable
crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy");
}
};
Ok(FramedStream(
Framed::new(stream, BytesCodec::new()),
addr,
None,
0,
))
}
ProxyScheme::Socks5 { .. } => {
log::trace!("Connect to remote socket5 proxy server: {}", proxy);
let stream = if let Some(auth) = self.intercept.maybe_auth() {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_password_and_socket(
stream,
target_addr,
&auth.user_name,
&auth.password,
),
)
.await??
} else {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_socket(stream, target_addr),
)
.await??
};
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
None,
0,
))
}
};
}
async fn https_connect_nativetls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = self.new_stream(local, proxy).await?;
let s = super::timeout(
self.ms_timeout,
self.https_connect_nativetls(
stream,
&target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await??;
upsert_tls_cache(
url,
TlsType::NativeTls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
pub async fn https_connect_nativetls<'a, Input>(
&self,
io: Input,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
{
let mut tls_connector_builder = native_tls::TlsConnector::builder();
if danger_accept_invalid_cert {
tls_connector_builder.danger_accept_invalid_certs(true);
}
let tls_connector = TlsConnector::from(tls_connector_builder.build()?);
let stream = tls_connector
.connect(&self.intercept.get_domain()?, io)
.await?;
self.http_connect(stream, target_addr).await
}
#[async_recursion]
async fn https_connect_rustls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
stream: Option<tokio::net::TcpStream>,
target_addr: &TargetAddr<'a>,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
origin_danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = stream.unwrap_or(self.new_stream(local, proxy).await?);
match super::timeout(
self.ms_timeout,
self.https_connect_rustls(
stream,
target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await?
{
Ok(s) => {
upsert_tls_cache(
&url,
TlsType::Rustls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
Err(e) => {
// NOTE: Maybe it's better to check if the error is related to TLS here. (ProxyError::IoError(e), or ProxyError::NativeTlsError(e))
// But we can only get the error when the TLS protocol is TLSv1.1.
// The error message of the following is unclear:
// https://github.com/rustdesk/rustdesk-server-pro/issues/189#issuecomment-1895701480
// So we just try to fallback unconditionally here.
//
// If the protocol is TLS 1.1, the error is:
// 1. "IO Error: received fatal alert: ProtocolVersion"
// 2. "IO Error: An existing connection was forcibly closed by the remote host. (os error 10054)" on Windows sometimes.
//
// If the cert verification fails, the error is:
// "IO Error: invalid peer certificate: UnknownIssuer"
let s = if danger_accept_invalid_cert.is_none() {
log::warn!(
"Falling back to rustls-tls (accept invalid cert) for HTTPS proxy server."
);
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
None,
target_addr,
is_tls_type_cached,
Some(true),
origin_danger_accept_invalid_cert,
)
.await?
} else if !is_tls_type_cached {
log::warn!("Falling back to native-tls for HTTPS proxy server.");
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
&target_addr,
origin_danger_accept_invalid_cert,
)
.await?
} else {
log::error!(
"Failed to connect to HTTPS proxy server with native-tls: {:?}.",
e
);
bail!(e)
};
Ok(s)
}
}
}
pub async fn https_connect_rustls<'a, Input>(
&self,
io: Input,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<RustlsTlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
{
use std::convert::TryFrom;
let url_domain = self.intercept.get_domain()?;
let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str())
.map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))?
.to_owned();
let client_config = crate::verifier::client_config(danger_accept_invalid_cert)
.map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?;
let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config));
let stream = tls_connector.connect(domain, io).await?;
self.http_connect(stream, target_addr).await
}
pub async fn http_connect<'a, Input>(
&self,
io: Input,
target_addr: &TargetAddr<'a>,
) -> Result<BufStream<Input>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
{
let mut stream = BufStream::new(io);
let (domain, port) = get_domain_and_port(target_addr)?;
let request = self.make_request(&domain, port);
stream.write_all(request.as_bytes()).await?;
stream.flush().await?;
recv_and_check_response(&mut stream).await?;
Ok(stream)
}
fn make_request(&self, host: &str, port: u16) -> String {
let mut request = format!(
"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n",
host = host,
port = port
);
if let Some(auth) = self.intercept.maybe_auth() {
request = format!("{}{}", request, auth.get_proxy_authorization());
}
request.push_str("\r\n");
request
}
}
fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> {
match target_addr {
tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)),
}
}
async fn get_response<IO>(stream: &mut BufStream<IO>) -> Result<String, ProxyError>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
use tokio::io::AsyncBufReadExt;
let mut response = String::new();
loop {
if stream.read_line(&mut response).await? == 0 {
return Err(ProxyError::EndOfFile);
}
if MAXIMUM_RESPONSE_HEADER_LENGTH < response.len() {
return Err(ProxyError::MaximumResponseHeaderLengthExceeded(
response.len(),
));
}
if response.ends_with("\r\n\r\n") {
return Ok(response);
}
}
}
async fn recv_and_check_response<IO>(stream: &mut BufStream<IO>) -> Result<(), ProxyError>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
let response_string = get_response(stream).await?;
let mut response_headers = [EMPTY_HEADER; MAXIMUM_RESPONSE_HEADERS];
let mut response = Response::new(&mut response_headers);
let response_bytes = response_string.into_bytes();
response.parse(&response_bytes)?;
return match response.code {
Some(code) => {
if code == 200 {
Ok(())
} else {
Err(ProxyError::HttpCode200(code))
}
}
None => Err(ProxyError::NoHttpCode),
};
}
+348
View File
@@ -0,0 +1,348 @@
#[cfg(feature = "webrtc")]
use crate::webrtc::{self, is_webrtc_endpoint};
use crate::{
config::{Config, NetworkType},
tcp::FramedStream,
udp::FramedSocket,
websocket::{self, check_ws, is_ws_endpoint},
ResultType, Stream,
};
use anyhow::Context;
use std::{net::SocketAddr, sync::Arc};
use tokio::net::{ToSocketAddrs, UdpSocket};
use tokio_socks::{IntoTargetAddr, TargetAddr};
#[inline]
pub fn check_port<T: std::string::ToString>(host: T, port: i32) -> String {
let host = host.to_string();
if crate::is_ipv6_str(&host) {
if host.starts_with('[') {
return host;
}
return format!("[{host}]:{port}");
}
if !host.contains(':') {
return format!("{host}:{port}");
}
host
}
#[inline]
pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String {
let host = host.to_string();
if crate::is_ipv6_str(&host) {
if host.starts_with('[') {
let tmp: Vec<&str> = host.split("]:").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}]:{}", tmp[0], port + offset);
}
}
}
} else if host.contains(':') {
let tmp: Vec<&str> = host.split(':').collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}:{}", tmp[0], port + offset);
}
}
}
host
}
pub fn split_host_port<T: std::string::ToString>(host: T) -> Option<(String, i32)> {
let host = host.to_string();
if crate::is_ipv6_str(&host) {
if host.starts_with('[') {
let tmp: Vec<&str> = host.split("]:").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return Some((format!("{}]", tmp[0]), port));
}
}
}
} else if host.contains(':') {
let tmp: Vec<&str> = host.split(':').collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return Some((tmp[0].to_string(), port));
}
}
}
None
}
pub fn test_if_valid_server(host: &str, test_with_proxy: bool) -> String {
let host = check_port(host, 0);
use std::net::ToSocketAddrs;
if test_with_proxy && NetworkType::ProxySocks == Config::get_network_type() {
test_if_valid_server_for_proxy_(&host)
} else {
match host.to_socket_addrs() {
Err(err) => err.to_string(),
Ok(_) => "".to_owned(),
}
}
}
#[inline]
pub fn test_if_valid_server_for_proxy_(host: &str) -> String {
// `&host.into_target_addr()` is defined in `tokio-socs`, but is a common pattern for testing,
// it can be used for both `socks` and `http` proxy.
match &host.into_target_addr() {
Err(err) => err.to_string(),
Ok(_) => "".to_owned(),
}
}
pub trait IsResolvedSocketAddr {
fn resolve(&self) -> Option<&SocketAddr>;
}
impl IsResolvedSocketAddr for SocketAddr {
fn resolve(&self) -> Option<&SocketAddr> {
Some(self)
}
}
impl IsResolvedSocketAddr for String {
fn resolve(&self) -> Option<&SocketAddr> {
None
}
}
impl IsResolvedSocketAddr for &str {
fn resolve(&self) -> Option<&SocketAddr> {
None
}
}
// This function checks if the target is a websocket endpoint and connects accordingly.
#[inline]
pub async fn connect_tcp<
't,
T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display,
>(
target: T,
ms_timeout: u64,
) -> ResultType<crate::Stream> {
#[cfg(feature = "webrtc")]
if is_webrtc_endpoint(&target.to_string()) {
return Ok(Stream::WebRTC(
webrtc::WebRTCStream::new(&target.to_string(), false, ms_timeout).await?,
));
}
let target_str = check_ws(&target.to_string());
if is_ws_endpoint(&target_str) {
return Ok(Stream::WebSocket(
websocket::WsFramedStream::new(target_str, None, None, ms_timeout).await?,
));
}
connect_tcp_local(target, None, ms_timeout).await
}
// This function connects directly to the target without checking for websocket endpoints.
pub async fn connect_tcp_local<
't,
T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display,
>(
target: T,
local: Option<SocketAddr>,
ms_timeout: u64,
) -> ResultType<Stream> {
if let Some(conf) = Config::get_socks() {
return Ok(Stream::Tcp(
FramedStream::connect(target, local, &conf, ms_timeout).await?,
));
}
if let Some(target_addr) = target.resolve() {
if let Some(local_addr) = local {
if local_addr.is_ipv6() && target_addr.is_ipv4() {
let resolved_target = query_nip_io(target_addr).await?;
return Ok(Stream::Tcp(
FramedStream::new(resolved_target, Some(local_addr), ms_timeout).await?,
));
}
}
}
Ok(Stream::Tcp(
FramedStream::new(target, local, ms_timeout).await?,
))
}
#[inline]
pub fn is_ipv4(target: &TargetAddr<'_>) -> bool {
match target {
TargetAddr::Ip(addr) => addr.is_ipv4(),
_ => true,
}
}
#[inline]
pub async fn query_nip_io(addr: &SocketAddr) -> ResultType<SocketAddr> {
tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port()))
.await?
.find(|x| x.is_ipv6())
.context("Failed to get ipv6 from nip.io")
}
#[inline]
pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String {
if !ipv4 && crate::is_ipv4_str(&addr) {
if let Some(ip) = addr.split(':').next() {
return addr.replace(ip, &format!("{ip}.nip.io"));
}
}
addr
}
async fn test_target(target: &str) -> ResultType<SocketAddr> {
if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await {
if let Ok(addr) = s.peer_addr() {
return Ok(addr);
}
}
tokio::net::lookup_host(target)
.await?
.next()
.context(format!("Failed to look up host for {target}"))
}
#[inline]
pub async fn new_direct_udp_for(target: &str) -> ResultType<(Arc<UdpSocket>, SocketAddr)> {
let peer_addr = test_target(target).await?;
let local_addr = Config::get_any_listen_addr(peer_addr.is_ipv4());
let socket = UdpSocket::bind(local_addr).await?;
Ok((Arc::new(socket), peer_addr))
}
#[inline]
pub async fn new_udp_for(
target: &str,
ms_timeout: u64,
) -> ResultType<(FramedSocket, TargetAddr<'static>)> {
let (ipv4, target) = if NetworkType::Direct == Config::get_network_type() {
let addr = test_target(target).await?;
(addr.is_ipv4(), addr.into_target_addr()?)
} else {
(true, target.into_target_addr()?)
};
Ok((
new_udp(Config::get_any_listen_addr(ipv4), ms_timeout).await?,
target.to_owned(),
))
}
async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> {
match Config::get_socks() {
None => Ok(FramedSocket::new(local).await?),
Some(conf) => {
let socket = FramedSocket::new_proxy(
conf.proxy.as_str(),
local,
conf.username.as_str(),
conf.password.as_str(),
ms_timeout,
)
.await?;
Ok(socket)
}
}
}
pub async fn rebind_udp_for(
target: &str,
) -> ResultType<Option<(FramedSocket, TargetAddr<'static>)>> {
if Config::get_network_type() != NetworkType::Direct {
return Ok(None);
}
let addr = test_target(target).await?;
let v4 = addr.is_ipv4();
Ok(Some((
FramedSocket::new(Config::get_any_listen_addr(v4)).await?,
addr.into_target_addr()?.to_owned(),
)))
}
#[cfg(test)]
mod tests {
use std::net::ToSocketAddrs;
use super::*;
#[test]
fn test_nat64() {
test_nat64_async();
}
#[tokio::main(flavor = "current_thread")]
async fn test_nat64_async() {
assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), true), "1.1.1.1");
assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), false), "1.1.1.1.nip.io");
assert_eq!(
ipv4_to_ipv6("1.1.1.1:8080".to_owned(), false),
"1.1.1.1.nip.io:8080"
);
assert_eq!(
ipv4_to_ipv6("rustdesk.com".to_owned(), false),
"rustdesk.com"
);
if ("rustdesk.com:80")
.to_socket_addrs()
.unwrap()
.next()
.unwrap()
.is_ipv6()
{
assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap())
.await
.unwrap()
.is_ipv6());
return;
}
assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err());
}
#[test]
fn test_test_if_valid_server() {
assert!(!test_if_valid_server("a", false).is_empty());
// on Linux, "1" is resolved to "0.0.0.1"
assert!(test_if_valid_server("1.1.1.1", false).is_empty());
assert!(test_if_valid_server("1.1.1.1:1", false).is_empty());
assert!(test_if_valid_server("microsoft.com", false).is_empty());
assert!(test_if_valid_server("microsoft.com:1", false).is_empty());
// with proxy
// `:0` indicates `let host = check_port(host, 0);` is called.
assert!(test_if_valid_server_for_proxy_("a:0").is_empty());
assert!(test_if_valid_server_for_proxy_("1.1.1.1:0").is_empty());
assert!(test_if_valid_server_for_proxy_("1.1.1.1:1").is_empty());
assert!(test_if_valid_server_for_proxy_("abc.com:0").is_empty());
assert!(test_if_valid_server_for_proxy_("abcd.com:1").is_empty());
}
#[test]
fn test_check_port() {
assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12");
assert_eq!(check_port("1:2", 32), "[1:2]:32");
assert_eq!(check_port("z1:2", 32), "z1:2");
assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32");
assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32");
assert_eq!(check_port("test.com:32", 0), "test.com:32");
assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13");
assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13");
assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4");
assert_eq!(increase_port("test.com", 1), "test.com");
assert_eq!(increase_port("test.com:13", 4), "test.com:17");
assert_eq!(increase_port("1:13", 4), "1:13");
assert_eq!(increase_port("22:1:13", 4), "22:1:13");
assert_eq!(increase_port("z1:2", 1), "z1:3");
}
}
+149
View File
@@ -0,0 +1,149 @@
use crate::{config, tcp, websocket, ResultType};
#[cfg(feature = "webrtc")]
use crate::webrtc;
use sodiumoxide::crypto::secretbox::Key;
use std::net::SocketAddr;
use tokio::net::TcpStream;
// support Websocket and tcp.
pub enum Stream {
#[cfg(feature = "webrtc")]
WebRTC(webrtc::WebRTCStream),
WebSocket(websocket::WsFramedStream),
Tcp(tcp::FramedStream),
}
impl Stream {
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_send_timeout(ms),
Stream::WebSocket(s) => s.set_send_timeout(ms),
Stream::Tcp(s) => s.set_send_timeout(ms),
}
}
#[inline]
pub fn set_raw(&mut self) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_raw(),
Stream::WebSocket(s) => s.set_raw(),
Stream::Tcp(s) => s.set_raw(),
}
}
#[inline]
pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.send_bytes(bytes).await,
Stream::WebSocket(s) => s.send_bytes(bytes).await,
Stream::Tcp(s) => s.send_bytes(bytes).await,
}
}
#[inline]
pub async fn send_raw(&mut self, bytes: Vec<u8>) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.send_raw(bytes).await,
Stream::WebSocket(s) => s.send_raw(bytes).await,
Stream::Tcp(s) => s.send_raw(bytes).await,
}
}
#[inline]
pub fn set_key(&mut self, key: Key) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_key(key),
Stream::WebSocket(s) => s.set_key(key),
Stream::Tcp(s) => s.set_key(key),
}
}
#[inline]
pub fn is_secured(&self) -> bool {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.is_secured(),
Stream::WebSocket(s) => s.is_secured(),
Stream::Tcp(s) => s.is_secured(),
}
}
#[inline]
pub async fn next_timeout(
&mut self,
timeout: u64,
) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.next_timeout(timeout).await,
Stream::WebSocket(s) => s.next_timeout(timeout).await,
Stream::Tcp(s) => s.next_timeout(timeout).await,
}
}
/// establish connect from websocket
#[inline]
pub async fn connect_websocket(
url: impl AsRef<str>,
local_addr: Option<SocketAddr>,
proxy_conf: Option<&config::Socks5Server>,
timeout_ms: u64,
) -> ResultType<Self> {
let ws_stream =
websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?;
log::debug!("WebSocket connection established");
Ok(Self::WebSocket(ws_stream))
}
/// send message
#[inline]
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.send(msg).await,
Self::WebSocket(ws) => ws.send(msg).await,
Self::Tcp(tcp) => tcp.send(msg).await,
}
}
/// receive message
#[inline]
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.next().await,
Self::WebSocket(ws) => ws.next().await,
Self::Tcp(tcp) => tcp.next().await,
}
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.local_addr(),
Self::WebSocket(ws) => ws.local_addr(),
Self::Tcp(tcp) => tcp.local_addr(),
}
}
#[inline]
pub fn from(stream: TcpStream, stream_addr: SocketAddr) -> Self {
Self::Tcp(tcp::FramedStream::from(stream, stream_addr))
}
#[inline]
#[cfg(feature = "webrtc")]
pub fn get_webrtc_stream(&self) -> Option<webrtc::WebRTCStream> {
match self {
Self::WebRTC(s) => Some(s.clone()),
_ => None,
}
}
}
+344
View File
@@ -0,0 +1,344 @@
use crate::{bail, bytes_codec::BytesCodec, ResultType, config::Socks5Server, proxy::Proxy};
use anyhow::Context as AnyhowCtx;
use bytes::{BufMut, Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use protobuf::Message;
use sodiumoxide::crypto::{
box_,
secretbox::{self, Key, Nonce},
};
use std::{
io::{self, Error, ErrorKind},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs},
};
use tokio_socks::IntoTargetAddr;
use tokio_util::codec::Framed;
pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {}
pub struct DynTcpStream(pub Box<dyn TcpStreamTrait + Send + Sync>);
#[derive(Clone)]
pub struct Encrypt(pub Key, pub u64, pub u64);
pub struct FramedStream(
pub Framed<DynTcpStream, BytesCodec>,
pub SocketAddr,
pub Option<Encrypt>,
pub u64,
);
impl Deref for FramedStream {
type Target = Framed<DynTcpStream, BytesCodec>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for FramedStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Deref for DynTcpStream {
type Target = Box<dyn TcpStreamTrait + Send + Sync>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DynTcpStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
pub(crate) fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result<TcpSocket, std::io::Error> {
let socket = match addr {
std::net::SocketAddr::V4(..) => TcpSocket::new_v4()?,
std::net::SocketAddr::V6(..) => TcpSocket::new_v6()?,
};
if reuse {
// windows has no reuse_port, but its reuse_address
// almost equals to unix's reuse_port + reuse_address,
// though may introduce nondeterministic behavior
// illumos has no support for SO_REUSEPORT
#[cfg(all(unix, not(target_os = "illumos")))]
socket.set_reuseport(true).ok();
socket.set_reuseaddr(true).ok();
}
socket.bind(addr)?;
Ok(socket)
}
impl FramedStream {
pub async fn new<T: ToSocketAddrs + std::fmt::Display>(
remote_addr: T,
local_addr: Option<SocketAddr>,
ms_timeout: u64,
) -> ResultType<Self> {
for remote_addr in lookup_host(&remote_addr).await? {
let local = if let Some(addr) = local_addr {
addr
} else {
crate::config::Config::get_any_listen_addr(remote_addr.is_ipv4())
};
if let Ok(socket) = new_socket(local, true) {
if let Ok(Ok(stream)) =
super::timeout(ms_timeout, socket.connect(remote_addr)).await
{
stream.set_nodelay(true).ok();
let addr = stream.local_addr()?;
return Ok(Self(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
None,
0,
));
}
}
}
bail!(format!("Failed to connect to {remote_addr}"));
}
pub async fn connect<'t, T>(
target: T,
local_addr: Option<SocketAddr>,
proxy_conf: &Socks5Server,
ms_timeout: u64,
) -> ResultType<Self>
where
T: IntoTargetAddr<'t>,
{
let proxy = Proxy::from_conf(proxy_conf, Some(ms_timeout))?;
proxy.connect::<T>(target, local_addr).await
}
pub fn local_addr(&self) -> SocketAddr {
self.1
}
pub fn set_send_timeout(&mut self, ms: u64) {
self.3 = ms;
}
pub fn from(stream: impl TcpStreamTrait + Send + Sync + 'static, addr: SocketAddr) -> Self {
Self(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
None,
0,
)
}
pub fn set_raw(&mut self) {
self.0.codec_mut().set_raw();
self.2 = None;
}
pub fn is_secured(&self) -> bool {
self.2.is_some()
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
let mut msg = msg;
if let Some(key) = self.2.as_mut() {
msg = key.enc(&msg);
}
self.send_bytes(bytes::Bytes::from(msg)).await?;
Ok(())
}
#[inline]
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
if self.3 > 0 {
super::timeout(self.3, self.0.send(bytes)).await??;
} else {
self.0.send(bytes).await?;
}
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
let mut res = self.0.next().await;
if let Some(Ok(bytes)) = res.as_mut() {
if let Some(key) = self.2.as_mut() {
if let Err(err) = key.dec(bytes) {
return Some(Err(err));
}
}
}
res
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
if let Ok(res) = super::timeout(ms, self.next()).await {
res
} else {
None
}
}
pub fn set_key(&mut self, key: Key) {
self.2 = Some(Encrypt::new(key));
}
fn get_nonce(seqnum: u64) -> Nonce {
let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]);
nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes());
nonce
}
}
const DEFAULT_BACKLOG: u32 = 128;
pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType<TcpListener> {
if !reuse {
Ok(TcpListener::bind(addr).await?)
} else {
let addr = lookup_host(&addr)
.await?
.next()
.context("could not resolve to any address")?;
new_socket(addr, true)?
.listen(DEFAULT_BACKLOG)
.map_err(anyhow::Error::msg)
}
}
pub async fn listen_any(port: u16) -> ResultType<TcpListener> {
if let Ok(mut socket) = TcpSocket::new_v6() {
#[cfg(unix)]
{
// illumos has no support for SO_REUSEPORT
#[cfg(not(target_os = "illumos"))]
socket.set_reuseport(true).ok();
socket.set_reuseaddr(true).ok();
use std::os::unix::io::{FromRawFd, IntoRawFd};
let raw_fd = socket.into_raw_fd();
let sock2 = unsafe { socket2::Socket::from_raw_fd(raw_fd) };
sock2.set_only_v6(false).ok();
socket = unsafe { TcpSocket::from_raw_fd(sock2.into_raw_fd()) };
}
#[cfg(windows)]
{
use std::os::windows::prelude::{FromRawSocket, IntoRawSocket};
let raw_socket = socket.into_raw_socket();
let sock2 = unsafe { socket2::Socket::from_raw_socket(raw_socket) };
sock2.set_only_v6(false).ok();
socket = unsafe { TcpSocket::from_raw_socket(sock2.into_raw_socket()) };
}
if socket
.bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
.is_ok()
{
if let Ok(l) = socket.listen(DEFAULT_BACKLOG) {
return Ok(l);
}
}
}
Ok(new_socket(
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port),
true,
)?
.listen(DEFAULT_BACKLOG)?)
}
impl Unpin for DynTcpStream {}
impl AsyncRead for DynTcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
}
impl AsyncWrite for DynTcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}
impl<R: AsyncRead + AsyncWrite + Unpin> TcpStreamTrait for R {}
impl Encrypt {
pub fn new(key: Key) -> Self {
Self(key, 0, 0)
}
pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> {
if bytes.len() <= 1 {
return Ok(());
}
self.2 += 1;
let nonce = FramedStream::get_nonce(self.2);
match secretbox::open(bytes, &nonce, &self.0) {
Ok(res) => {
bytes.clear();
bytes.put_slice(&res);
Ok(())
}
Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")),
}
}
pub fn enc(&mut self, data: &[u8]) -> Vec<u8> {
self.1 += 1;
let nonce = FramedStream::get_nonce(self.1);
secretbox::seal(&data, &nonce, &self.0)
}
pub fn decode(
symmetric_data: &[u8],
their_pk_b: &[u8],
our_sk_b: &box_::SecretKey,
) -> ResultType<Key> {
if their_pk_b.len() != box_::PUBLICKEYBYTES {
anyhow::bail!("Handshake failed: pk length {}", their_pk_b.len());
}
let nonce = box_::Nonce([0u8; box_::NONCEBYTES]);
let mut pk_ = [0u8; box_::PUBLICKEYBYTES];
pk_[..].copy_from_slice(their_pk_b);
let their_pk_b = box_::PublicKey(pk_);
let symmetric_key = box_::open(symmetric_data, &nonce, &their_pk_b, &our_sk_b)
.map_err(|_| anyhow::anyhow!("Handshake failed: box decryption failure"))?;
if symmetric_key.len() != secretbox::KEYBYTES {
anyhow::bail!("Handshake failed: invalid secret key length from peer");
}
let mut key = [0u8; secretbox::KEYBYTES];
key[..].copy_from_slice(&symmetric_key);
Ok(Key(key))
}
}
+121
View File
@@ -0,0 +1,121 @@
use std::{collections::HashMap, sync::RwLock};
use crate::config::allow_insecure_tls_fallback;
#[derive(Debug, Clone, Copy)]
pub enum TlsType {
Plain,
NativeTls,
Rustls,
}
lazy_static::lazy_static! {
static ref URL_TLS_TYPE: RwLock<HashMap<String, TlsType>> = RwLock::new(HashMap::new());
static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock<HashMap<String, bool>> = RwLock::new(HashMap::new());
}
#[inline]
pub fn is_plain(url: &str) -> bool {
url.starts_with("ws://") || url.starts_with("http://")
}
// Extract domain from URL.
// e.g., "https://example.com/path" -> "example.com"
// "https://example.com:8080/path" -> "example.com:8080"
// See the tests for more examples.
#[inline]
fn get_domain_and_port_from_url(url: &str) -> &str {
// Remove scheme (e.g., http://, https://, ws://, wss://)
let scheme_end = url.find("://").map(|pos| pos + 3).unwrap_or(0);
let url2 = &url[scheme_end..];
// If userinfo is present, domain is after last '@'
let after_at = match url2.rfind('@') {
Some(pos) => &url2[pos + 1..],
None => url2,
};
// Find the end of domain (before '/' or '?')
let domain_end = after_at.find(&['/', '?'][..]).unwrap_or(after_at.len());
&after_at[..domain_end]
}
#[inline]
pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) {
if is_plain(url) {
return;
}
let domain_port = get_domain_and_port_from_url(url);
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE
.write()
.unwrap()
.insert(domain_port.to_string(), tls_type);
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.write()
.unwrap()
.insert(domain_port.to_string(), danger_accept_invalid_cert);
}
}
#[inline]
pub fn reset_tls_cache() {
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE.write().unwrap().clear();
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear();
}
}
#[inline]
pub fn get_cached_tls_type(url: &str) -> Option<TlsType> {
if is_plain(url) {
return Some(TlsType::Plain);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_TYPE.read().unwrap().get(domain_port).cloned()
}
#[inline]
pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option<bool> {
if !allow_insecure_tls_fallback() {
return Some(false);
}
if is_plain(url) {
return Some(false);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.read()
.unwrap()
.get(domain_port)
.cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_domain_and_port_from_url() {
for (url, expected_domain_port) in vec![
("http://example.com", "example.com"),
("https://example.com", "example.com"),
("ws://example.com/path", "example.com"),
("wss://example.com:8080/path", "example.com:8080"),
("https://user:pass@example.com", "example.com"),
("https://example.com?query=param", "example.com"),
("https://example.com:8443?query=param", "example.com:8443"),
("ftp://example.com/resource", "example.com"), // ftp scheme
("example.com/path", "example.com"), // no scheme
("example.com:8080/path", "example.com:8080"),
] {
let domain_port = get_domain_and_port_from_url(url);
assert_eq!(domain_port, expected_domain_port);
}
}
}
+171
View File
@@ -0,0 +1,171 @@
use crate::ResultType;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use protobuf::Message;
use socket2::{Domain, Socket, Type};
use std::net::SocketAddr;
use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket};
use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs};
use tokio_util::{codec::BytesCodec, udp::UdpFramed};
pub enum FramedSocket {
Direct(UdpFramed<BytesCodec>),
ProxySocks(Socks5UdpFramed),
}
fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result<Socket, std::io::Error> {
let socket = match addr {
SocketAddr::V4(..) => Socket::new(Domain::ipv4(), Type::dgram(), None),
SocketAddr::V6(..) => Socket::new(Domain::ipv6(), Type::dgram(), None),
}?;
if reuse {
// windows has no reuse_port, but its reuse_address
// almost equals to unix's reuse_port + reuse_address,
// though may introduce nondeterministic behavior
// illumos has no support for SO_REUSEPORT
#[cfg(all(unix, not(target_os = "illumos")))]
socket.set_reuse_port(true).ok();
socket.set_reuse_address(true).ok();
}
// only nonblocking work with tokio, https://stackoverflow.com/questions/64649405/receiver-on-tokiompscchannel-only-receives-messages-when-buffer-is-full
socket.set_nonblocking(true)?;
if buf_size > 0 {
socket.set_recv_buffer_size(buf_size).ok();
}
log::debug!(
"Receive buf size of udp {}: {:?}",
addr,
socket.recv_buffer_size()
);
if addr.is_ipv6() && addr.ip().is_unspecified() && addr.port() > 0 {
socket.set_only_v6(false).ok();
}
socket.bind(&addr.into())?;
Ok(socket)
}
impl FramedSocket {
pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
Self::new_reuse(addr, false, 0).await
}
pub async fn new_reuse<T: ToSocketAddrs>(
addr: T,
reuse: bool,
buf_size: usize,
) -> ResultType<Self> {
let addr = lookup_host(&addr)
.await?
.next()
.context("could not resolve to any address")?;
Ok(Self::Direct(UdpFramed::new(
UdpSocket::from_std(new_socket(addr, reuse, buf_size)?.into_udp_socket())?,
BytesCodec::new(),
)))
}
pub async fn new_proxy<'a, 't, P: ToProxyAddrs, T: ToSocketAddrs>(
proxy: P,
local: T,
username: &'a str,
password: &'a str,
ms_timeout: u64,
) -> ResultType<Self> {
let framed = if username.trim().is_empty() {
super::timeout(ms_timeout, Socks5UdpFramed::connect(proxy, Some(local))).await??
} else {
super::timeout(
ms_timeout,
Socks5UdpFramed::connect_with_password(proxy, Some(local), username, password),
)
.await??
};
log::trace!(
"Socks5 udp connected, local addr: {:?}, target addr: {}",
framed.local_addr(),
framed.socks_addr()
);
Ok(Self::ProxySocks(framed))
}
#[inline]
pub async fn send(
&mut self,
msg: &impl Message,
addr: impl IntoTargetAddr<'_>,
) -> ResultType<()> {
let addr = addr.into_target_addr()?.to_owned();
let send_data = Bytes::from(msg.write_to_bytes()?);
match self {
Self::Direct(f) => {
if let TargetAddr::Ip(addr) = addr {
f.send((send_data, addr)).await?
}
}
Self::ProxySocks(f) => f.send((send_data, addr)).await?,
};
Ok(())
}
// https://stackoverflow.com/a/68733302/1926020
#[inline]
pub async fn send_raw(
&mut self,
msg: &'static [u8],
addr: impl IntoTargetAddr<'static>,
) -> ResultType<()> {
let addr = addr.into_target_addr()?.to_owned();
match self {
Self::Direct(f) => {
if let TargetAddr::Ip(addr) = addr {
f.send((Bytes::from(msg), addr)).await?
}
}
Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?,
};
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> {
match self {
Self::Direct(f) => match f.next().await {
Some(Ok((data, addr))) => {
Some(Ok((data, addr.into_target_addr().ok()?.to_owned())))
}
Some(Err(e)) => Some(Err(anyhow!(e))),
None => None,
},
Self::ProxySocks(f) => match f.next().await {
Some(Ok((data, _))) => Some(Ok((data.data, data.dst_addr))),
Some(Err(e)) => Some(Err(anyhow!(e))),
None => None,
},
}
}
#[inline]
pub async fn next_timeout(
&mut self,
ms: u64,
) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> {
if let Ok(res) =
tokio::time::timeout(std::time::Duration::from_millis(ms), self.next()).await
{
res
} else {
None
}
}
pub fn local_addr(&self) -> Option<SocketAddr> {
if let FramedSocket::Direct(x) = self {
if let Ok(v) = x.get_ref().local_addr() {
return Some(v);
}
}
None
}
}
+257
View File
@@ -0,0 +1,257 @@
use crate::ResultType;
use rustls_pki_types::{ServerName, UnixTime};
use std::sync::Arc;
use tokio_rustls::rustls::{self, client::WebPkiServerVerifier, ClientConfig};
use tokio_rustls::rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
DigitallySignedStruct, Error as TLSError, SignatureScheme,
};
// https://github.com/seanmonstar/reqwest/blob/fd61bc93e6f936454ce0b978c6f282f06eee9287/src/tls.rs#L608
#[derive(Debug)]
pub(crate) struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, TLSError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
]
}
}
/// A certificate verifier that tries a primary verifier first,
/// and falls back to a platform verifier if the primary fails.
#[cfg(any(target_os = "android", target_os = "ios"))]
#[derive(Debug)]
struct FallbackPlatformVerifier {
primary: Arc<dyn ServerCertVerifier>,
fallback: Arc<dyn ServerCertVerifier>,
}
#[cfg(any(target_os = "android", target_os = "ios"))]
impl FallbackPlatformVerifier {
fn with_platform_fallback(
primary: Arc<dyn ServerCertVerifier>,
provider: Arc<rustls::crypto::CryptoProvider>,
) -> Result<Self, TLSError> {
#[cfg(target_os = "android")]
if !crate::config::ANDROID_RUSTLS_PLATFORM_VERIFIER_INITIALIZED
.load(std::sync::atomic::Ordering::Relaxed)
{
return Err(TLSError::General(
"rustls-platform-verifier not initialized".to_string(),
));
}
let fallback = Arc::new(rustls_platform_verifier::Verifier::new(provider)?);
Ok(Self { primary, fallback })
}
}
#[cfg(any(target_os = "android", target_os = "ios"))]
impl ServerCertVerifier for FallbackPlatformVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls_pki_types::CertificateDer<'_>,
intermediates: &[rustls_pki_types::CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, TLSError> {
match self.primary.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(verified) => Ok(verified),
Err(primary_err) => {
match self.fallback.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(verified) => Ok(verified),
Err(fallback_err) => {
log::error!(
"Both primary and fallback verifiers failed to verify server certificate, primary error: {:?}, fallback error: {:?}",
primary_err,
fallback_err
);
Err(primary_err)
}
}
}
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
// Both WebPkiServerVerifier and rustls_platform_verifier use the same signature verification implementation.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L278
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/crypto/mod.rs#L17
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/android.rs#L9
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/apple.rs#L6
self.primary.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
// Same implementation as verify_tls12_signature.
self.primary.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
// Both WebPkiServerVerifier and rustls_platform_verifier use the same crypto provider,
// so their supported signature schemes are identical.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L172C52-L172C85
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/android.rs#L327
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/apple.rs#L304
self.primary.supported_verify_schemes()
}
}
fn webpki_server_verifier(
provider: Arc<rustls::crypto::CryptoProvider>,
) -> ResultType<Arc<WebPkiServerVerifier>> {
// Load root certificates from both bundled webpki_roots and system-native certificate stores.
// This approach is consistent with how reqwest and tokio-tungstenite handle root certificates.
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L95
// https://github.com/seanmonstar/reqwest/blob/b126ca49da7897e5d676639cdbf67a0f6838b586/src/async_impl/client.rs#L643
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let rustls_native_certs::CertificateResult { certs, errors, .. } =
rustls_native_certs::load_native_certs();
if !errors.is_empty() {
log::warn!("native root CA certificate loading errors: {errors:?}");
}
root_cert_store.add_parsable_certificates(certs);
// Build verifier using with_root_certificates behavior (WebPkiServerVerifier without CRLs).
// Both reqwest and tokio-tungstenite use this approach.
// https://github.com/seanmonstar/reqwest/blob/b126ca49da7897e5d676639cdbf67a0f6838b586/src/async_impl/client.rs#L749
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L127
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/client/builder.rs#L47
// with_root_certificates creates a WebPkiServerVerifier without revocation checking:
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L177
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L168
// Since no CRL is provided (as is the case here), we must explicitly set allow_unknown_revocation_status()
// to match the behavior of with_root_certificates, which allows unknown revocation status by default.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L37
// Note: build() only returns an error if the root certificate store is empty, which won't happen here.
let verifier = rustls::client::WebPkiServerVerifier::builder_with_provider(
Arc::new(root_cert_store),
provider.clone(),
)
.allow_unknown_revocation_status()
.build()
.map_err(|e| anyhow::anyhow!(e))?;
Ok(verifier)
}
pub fn client_config(danger_accept_invalid_cert: bool) -> ResultType<ClientConfig> {
if danger_accept_invalid_cert {
client_config_danger()
} else {
client_config_safe()
}
}
pub fn client_config_safe() -> ResultType<ClientConfig> {
// Use the default builder which uses the default protocol versions and crypto provider.
// The with_protocol_versions API has been removed in rustls master branch:
// https://github.com/rustls/rustls/pull/2599
// This approach is consistent with tokio-tungstenite's usage:
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L126
let config_builder = rustls::ClientConfig::builder();
let provider = config_builder.crypto_provider().clone();
let webpki_verifier = webpki_server_verifier(provider.clone())?;
#[cfg(any(target_os = "android", target_os = "ios"))]
{
match FallbackPlatformVerifier::with_platform_fallback(webpki_verifier.clone(), provider) {
Ok(fallback_verifier) => {
let config = config_builder
.dangerous()
.with_custom_certificate_verifier(Arc::new(fallback_verifier))
.with_no_client_auth();
Ok(config)
}
Err(e) => {
log::error!(
"Failed to create fallback verifier: {:?}, use webpki verifier instead",
e
);
let config = config_builder
.with_webpki_verifier(webpki_verifier)
.with_no_client_auth();
Ok(config)
}
}
}
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
let config = config_builder
.with_webpki_verifier(webpki_verifier)
.with_no_client_auth();
Ok(config)
}
}
pub fn client_config_danger() -> ResultType<ClientConfig> {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
Ok(config)
}
+770
View File
@@ -0,0 +1,770 @@
use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::RTCDataChannel;
use webrtc::ice::mdns::MulticastDnsMode;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use bytes::{Bytes, BytesMut};
use tokio::sync::watch;
use tokio::sync::Mutex;
use tokio::time::timeout;
use url::Url;
use crate::config;
use crate::protobuf::Message;
use crate::sodiumoxide::crypto::secretbox::Key;
use crate::ResultType;
pub struct WebRTCStream {
pc: Arc<RTCPeerConnection>,
stream: Arc<Mutex<Arc<RTCDataChannel>>>,
state_notify: watch::Receiver<bool>,
send_timeout: u64,
}
/// Standard maximum message size for WebRTC data channels (RFC 8831, 65535 bytes).
/// Most browsers, including Chromium, enforce this protocol limit.
const DATA_CHANNEL_BUFFER_SIZE: u16 = u16::MAX;
// use 3 public STUN servers to find out the NAT type, 2 must be the same address but different ports
// https://stackoverflow.com/questions/72805316/determine-nat-mapping-behaviour-using-two-stun-servers
// luckily nextcloud supports two ports for STUN
// unluckily webrtc-rs does not use the same port to do the STUN request
static DEFAULT_ICE_SERVERS: [&str; 3] = [
"stun:stun.cloudflare.com:3478",
"stun:stun.nextcloud.com:3478",
"stun:stun.nextcloud.com:443",
];
lazy_static::lazy_static! {
static ref SESSIONS: Arc::<Mutex<HashMap<String, WebRTCStream>>> = Default::default();
}
impl Clone for WebRTCStream {
fn clone(&self) -> Self {
WebRTCStream {
pc: self.pc.clone(),
stream: self.stream.clone(),
state_notify: self.state_notify.clone(),
send_timeout: self.send_timeout,
}
}
}
impl WebRTCStream {
#[inline]
fn get_remote_offer(endpoint: &str) -> ResultType<String> {
// Ensure the endpoint starts with the "webrtc://" prefix
if !endpoint.starts_with("webrtc://") {
return Err(
Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into(),
);
}
// Extract the Base64-encoded SDP part
let encoded_sdp = &endpoint["webrtc://".len()..];
// Decode the Base64 string
let decoded_bytes = BASE64_STANDARD
.decode(encoded_sdp)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP"))?;
Ok(String::from_utf8(decoded_bytes).map_err(|_| {
Error::new(
ErrorKind::InvalidInput,
"Failed to convert decoded bytes to UTF-8",
)
})?)
}
#[inline]
fn sdp_to_endpoint(sdp: &str) -> String {
let encoded_sdp = BASE64_STANDARD.encode(sdp);
format!("webrtc://{}", encoded_sdp)
}
#[inline]
fn get_key_for_sdp(sdp: &RTCSessionDescription) -> ResultType<String> {
let binding = sdp.unmarshal()?;
let Some(fingerprint) = binding.attribute("fingerprint") else {
// find fingerprint attribute in media descriptions
for media in &binding.media_descriptions {
if media.media_name.media != "application" {
continue;
}
if let Some(fp) = media
.attributes
.iter()
.find(|x| x.key == "fingerprint")
.and_then(|x| x.value.clone())
{
return Ok(fp);
}
}
return Err(anyhow::anyhow!("SDP fingerprint attribute not found"));
};
Ok(fingerprint.to_string())
}
#[inline]
fn get_key_for_sdp_json(sdp_json: &str) -> ResultType<String> {
if sdp_json.is_empty() {
return Ok("".to_string());
}
let sdp = serde_json::from_str::<RTCSessionDescription>(&sdp_json)?;
Self::get_key_for_sdp(&sdp)
}
#[inline]
async fn get_key_for_peer(pc: &Arc<RTCPeerConnection>, is_local: bool) -> ResultType<String> {
let Some(desc) = (match is_local {
true => pc.local_description().await,
false => pc.remote_description().await,
}) else {
return Err(anyhow::anyhow!("PeerConnection description is not set"));
};
Self::get_key_for_sdp(&desc)
}
#[inline]
fn get_ice_server_from_url(url: &str) -> Option<RTCIceServer> {
// standard url format with turn scheme: turn://user:pass@host:port
match Url::parse(url) {
Ok(u) => {
if u.scheme() == "turn"
|| u.scheme() == "turns"
|| u.scheme() == "stun"
|| u.scheme() == "stuns"
{
Some(RTCIceServer {
urls: vec![format!(
"{}:{}:{}",
u.scheme(),
u.host_str().unwrap_or_default(),
u.port().unwrap_or(3478)
)],
username: u.username().to_string(),
credential: u.password().unwrap_or_default().to_string(),
..Default::default()
})
} else {
None
}
}
Err(_) => None,
}
}
#[inline]
fn get_ice_servers() -> Vec<RTCIceServer> {
let mut ice_servers = Vec::new();
let cfg = config::Config::get_option(config::keys::OPTION_ICE_SERVERS);
let mut has_stun = false;
for url in cfg.split(',').map(str::trim) {
if let Some(ice_server) = Self::get_ice_server_from_url(url) {
// Detect STUN in user config
if ice_server
.urls
.iter()
.any(|u| u.starts_with("stun:") || u.starts_with("stuns:"))
{
has_stun = true;
}
ice_servers.push(ice_server);
}
}
// If there is no STUN (either TURN-only or empty config) → prepend defaults
if !has_stun {
ice_servers.insert(
0,
RTCIceServer {
urls: DEFAULT_ICE_SERVERS.iter().map(|s| s.to_string()).collect(),
..Default::default()
},
);
}
ice_servers
}
pub async fn new(
remote_endpoint: &str,
force_relay: bool,
ms_timeout: u64,
) -> ResultType<Self> {
log::debug!("New webrtc stream to endpoint: {}", remote_endpoint);
let remote_offer = if remote_endpoint.is_empty() {
"".into()
} else {
Self::get_remote_offer(remote_endpoint)?
};
let mut key = Self::get_key_for_sdp_json(&remote_offer)?;
let sessions_lock = SESSIONS.lock().await;
if let Some(cached_stream) = sessions_lock.get(&key) {
if !key.is_empty() {
log::debug!("Start webrtc with cached peer");
return Ok(cached_stream.clone());
}
}
drop(sessions_lock);
let start_local_offer = remote_offer.is_empty();
// Create a SettingEngine and enable Detach
let mut s = SettingEngine::default();
s.detach_data_channels();
s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled);
// Create the API object
let api = APIBuilder::new().with_setting_engine(s).build();
// Prepare the configuration, get ICE servers from config
let config = RTCConfiguration {
ice_servers: Self::get_ice_servers(),
ice_transport_policy: if force_relay {
RTCIceTransportPolicy::Relay
} else {
RTCIceTransportPolicy::All
},
..Default::default()
};
let (notify_tx, notify_rx) = watch::channel(false);
// Create a new RTCPeerConnection
let pc = Arc::new(api.new_peer_connection(config).await?);
let bootstrap_dc = if start_local_offer {
let dc_open_notify = notify_tx.clone();
// Create a data channel with label "bootstrap"
let dc = pc.create_data_channel("bootstrap", None).await?;
dc.on_open(Box::new(move || {
log::debug!("Local data channel bootstrap open.");
let _ = dc_open_notify.send(true);
Box::pin(async {})
}));
dc
} else {
// Wait for the data channel to be created by the remote peer
// Here we create a dummy data channel to satisfy the type system
Arc::new(RTCDataChannel::default())
};
let stream = Arc::new(Mutex::new(bootstrap_dc));
if !start_local_offer {
// Register data channel creation handling
let dc_open_notify = notify_tx.clone();
let stream_for_dc = stream.clone();
pc.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
let d_label = dc.label().to_owned();
let dc_open_notify2 = dc_open_notify.clone();
let stream_for_dc_clone = stream_for_dc.clone();
log::debug!("Remote data channel {} ready", d_label);
Box::pin(async move {
let mut stream_lock = stream_for_dc_clone.lock().await;
*stream_lock = dc.clone();
drop(stream_lock);
dc.on_open(Box::new(move || {
let _ = dc_open_notify2.send(true);
Box::pin(async {})
}));
})
}));
}
// This will notify you when the peer has connected/disconnected
let stream_for_close = stream.clone();
let pc_for_close = pc.clone();
pc.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
let stream_for_close2 = stream_for_close.clone();
let on_connection_notify = notify_tx.clone();
let pc_for_close2 = pc_for_close.clone();
Box::pin(async move {
log::debug!("WebRTC session peer connection state: {}", s);
match s {
RTCPeerConnectionState::Disconnected
| RTCPeerConnectionState::Failed
| RTCPeerConnectionState::Closed => {
let _ = on_connection_notify.send(true);
log::debug!("WebRTC session closing due to disconnected");
let _ = stream_for_close2.lock().await.close().await;
log::debug!("WebRTC session stream closed");
let mut sessions_lock = SESSIONS.lock().await;
match Self::get_key_for_peer(&pc_for_close2, start_local_offer).await {
Ok(k) => {
sessions_lock.remove(&k);
log::debug!("WebRTC session removed key: {}", k);
}
Err(e) => {
log::error!(
"Failed to extract key for peer during session cleanup: {:?}",
e
);
// Fallback: try to remove any session associated with this peer connection
let keys_to_remove: Vec<String> = sessions_lock
.iter()
.filter_map(|(key, session)| {
if Arc::ptr_eq(&session.pc, &pc_for_close2) {
Some(key.clone())
} else {
None
}
})
.collect();
for k in keys_to_remove {
sessions_lock.remove(&k);
log::debug!("WebRTC session removed by fallback key: {}", k);
}
}
}
}
_ => {}
}
})
}));
// process offer/answer
if start_local_offer {
let sdp = pc.create_offer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(sdp.clone()).await?;
let _ = gather_complete.recv().await;
log::debug!("local offer:\n{}", sdp.sdp);
// get local sdp key
key = Self::get_key_for_sdp(&sdp)?;
log::debug!("Start webrtc with local key: {}", key);
} else {
let sdp = serde_json::from_str::<RTCSessionDescription>(&remote_offer)?;
pc.set_remote_description(sdp.clone()).await?;
let answer = pc.create_answer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(answer).await?;
let _ = gather_complete.recv().await;
log::debug!("remote offer:\n{}", sdp.sdp);
// get remote sdp key
key = Self::get_key_for_sdp(&sdp)?;
log::debug!("Start webrtc with remote key: {}", key);
}
let mut final_lock = SESSIONS.lock().await;
if let Some(session) = final_lock.get(&key) {
pc.close().await.ok();
return Ok(session.clone());
}
let webrtc_stream = Self {
pc,
stream,
state_notify: notify_rx,
send_timeout: ms_timeout,
};
final_lock.insert(key, webrtc_stream.clone());
Ok(webrtc_stream)
}
#[inline]
pub async fn get_local_endpoint(&self) -> ResultType<String> {
if let Some(local_desc) = self.pc.local_description().await {
let sdp = serde_json::to_string(&local_desc)?;
let endpoint = Self::sdp_to_endpoint(&sdp);
Ok(endpoint)
} else {
Err(anyhow::anyhow!("Local desc is not set"))
}
}
#[inline]
pub async fn set_remote_endpoint(&self, endpoint: &str) -> ResultType<()> {
let offer = Self::get_remote_offer(endpoint)?;
log::debug!("WebRTC set remote sdp: {}", offer);
let sdp = serde_json::from_str::<RTCSessionDescription>(&offer)?;
self.pc.set_remote_description(sdp).await?;
Ok(())
}
#[inline]
pub fn set_raw(&mut self) {
// not-supported
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
}
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
self.send_timeout = ms;
}
#[inline]
pub fn set_key(&mut self, _key: Key) {
// not-supported
// WebRTC uses built-in DTLS encryption for secure communication.
// DTLS handles key exchange and encryption automatically, so explicit key management is not required.
}
#[inline]
pub fn is_secured(&self) -> bool {
true
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
self.send_bytes(Bytes::from(msg)).await
}
#[inline]
async fn wait_for_connect_result(&mut self) {
if *self.state_notify.borrow() {
return;
}
let _ = self.state_notify.changed().await;
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
if self.send_timeout > 0 {
match timeout(
Duration::from_millis(self.send_timeout),
self.wait_for_connect_result(),
)
.await
{
Ok(_) => {}
Err(_) => {
self.pc.close().await.ok();
return Err(Error::new(
ErrorKind::TimedOut,
"WebRTC send wait for connect timeout",
)
.into());
}
}
} else {
self.wait_for_connect_result().await;
}
let stream = self.stream.lock().await.clone();
stream.send(&bytes).await?;
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
self.wait_for_connect_result().await;
let stream = self.stream.lock().await.clone();
// TODO reuse buffer?
let mut buffer = BytesMut::zeroed(DATA_CHANNEL_BUFFER_SIZE as usize);
let dc = stream.detach().await.ok()?;
let n = match dc.read(&mut buffer).await {
Ok(n) => n,
Err(err) => {
self.pc.close().await.ok();
return Some(Err(Error::new(
ErrorKind::Other,
format!("data channel read error: {}", err),
)));
}
};
if n == 0 {
self.pc.close().await.ok();
return Some(Err(Error::new(
ErrorKind::Other,
"data channel read exited with 0 bytes",
)));
}
buffer.truncate(n);
Some(Ok(buffer))
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await {
Ok(res) => res,
Err(_) => None,
}
}
}
pub fn is_webrtc_endpoint(endpoint: &str) -> bool {
// use sdp base64 json string as endpoint, or prefix webrtc:
endpoint.starts_with("webrtc://")
}
#[cfg(test)]
mod tests {
use crate::config;
use crate::webrtc::WebRTCStream;
use crate::webrtc::DEFAULT_ICE_SERVERS;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
#[test]
fn test_webrtc_ice_url() {
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://example.com:3478")
.unwrap_or_default()
.urls[0],
"turn:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://example.com")
.unwrap_or_default()
.urls[0],
"turn:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123@example.com")
.unwrap_or_default()
.username,
"123"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123@example.com")
.unwrap_or_default()
.credential,
""
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123:321@example.com")
.unwrap_or_default()
.credential,
"321"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("stun://example.com:3478")
.unwrap_or_default()
.urls[0],
"stun:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("http://123:123@example.com:3478"),
None
);
config::Config::set_option("ice-servers".to_string(), "".to_string());
assert_eq!(
WebRTCStream::get_ice_servers()[0].urls[0],
DEFAULT_ICE_SERVERS[0].to_string()
);
config::Config::set_option(
"ice-servers".to_string(),
",stun://example.com,turn://example.com,sdf".to_string(),
);
assert_eq!(
WebRTCStream::get_ice_servers()[0].urls[0],
"stun:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_servers()[1].urls[0],
"turn:example.com:3478"
);
assert_eq!(WebRTCStream::get_ice_servers().len(), 2);
config::Config::set_option(
"ice-servers".to_string(),
"".to_string(),
);
}
#[test]
fn test_webrtc_session_key() {
let mut sdp_str = "".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.unwrap_or_default(),
""
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
).unwrap_or_default(),
"sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
).unwrap_or_default(),
"sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT"
.to_owned();
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.is_err(),
"can not find fingerprint attribute"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=audio 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.is_err(),
"can not find datachannel fingerprint attribute"
);
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer("".to_owned()).unwrap_or_default()
)
.is_err(),
"invalid sdp should error"
);
assert!(
WebRTCStream::get_key_for_sdp_json("{}").is_err(),
"empty sdp json should error"
);
assert!(
WebRTCStream::get_key_for_sdp_json("{ss}").is_err(),
"invalid sdp json should error"
);
let endpoint = "webrtc://eyJ0eXBlIjoiYW5zd2VyIiwic2RwIjoidj0wXHJcbm89LSA0MTA1NDk3NTY2NDgyMTQzODEwIDYwMzk1NzQw\
MCBJTiBJUDQgMC4wLjAuMFxyXG5zPS1cclxudD0wIDBcclxuYT1maW5nZXJwcmludDpzaGEtMjU2IDYxOjYwOjc0OjQwOjI4OkNFOjBCOjBDOjc1OjRCOj\
EwOjlBOkVFOjc3OkY1OjQ0OjU3Ojg0OjUxOkRCOjA0OjkyOjRBOjEwOjFDOjRFOjVGOjdFOkYxOkIzOjcxOjIyXHJcbmE9Z3JvdXA6QlVORExFIDBcclxu\
YT1leHRtYXAtYWxsb3ctbWl4ZWRcclxubT1hcHBsaWNhdGlvbiA5IFVEUC9EVExTL1NDVFAgd2VicnRjLWRhdGFjaGFubmVsXHJcbmM9SU4gSVA0IDAuMC\
4wLjBcclxuYT1zZXR1cDphY3RpdmVcclxuYT1taWQ6MFxyXG5hPXNlbmRyZWN2XHJcbmE9c2N0cC1wb3J0OjUwMDBcclxuYT1pY2UtdWZyYWc6SHlnU1Rr\
V2RsRlpHRG1XWlxyXG5hPWljZS1wd2Q6SkJneFZWaGZveVhHdHZha1VWcnBQeHVOSVpMU3llS1pcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDEgdWRwID\
IxMzA3MDY0MzEgMTkyLjE2OC4xLjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDIgdWRwIDIxMzA3MDY0MzEgMTkyLjE2OC4x\
LjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6MTg2MTA0NTE5MCAxIHVkcCAxNjk0NDk4ODE1IDE0LjIxMi42OC4xMiAyNzAwNCB0eXAgc3JmbH\
ggcmFkZHIgMC4wLjAuMCBycG9ydCA2NDAwOFxyXG5hPWNhbmRpZGF0ZToxODYxMDQ1MTkwIDIgdWRwIDE2OTQ0OTg4MTUgMTQuMjEyLjY4LjEyIDI3MDA0\
IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNcclxuIn0=".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp_json(
&WebRTCStream::get_remote_offer(&endpoint).unwrap_or_default()
).unwrap_or_default(),
"sha-256 61:60:74:40:28:CE:0B:0C:75:4B:10:9A:EE:77:F5:44:57:84:51:DB:04:92:4A:10:1C:4E:5F:7E:F1:B3:71:22"
);
}
#[tokio::test]
async fn test_webrtc_new_stream() {
let mut endpoint = "webrtc://sdfsdf".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"invalid webrtc endpoint should error"
);
endpoint = "wss://sdfsdf".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"invalid webrtc endpoint should error"
);
assert!(
WebRTCStream::new("", false, 10000).await.is_ok(),
"local webrtc endpoint should ok"
);
endpoint = "webrtc://eyJ0eXBlIjoiYW5zd2VyIiwic2RwIjoidj0wXHJcbm89LSA0MTA1NDk3NTY2NDgyMTQzODEwIDYwMzk1NzQw\
MCBJTiBJUDQgMC4wLjAuMFxyXG5zPS1cclxudD0wIDBcclxuYT1maW5nZXJwcmludDpzaGEtMjU2IDYxOjYwOjc0OjQwOjI4OkNFOjBCOjBDOjc1OjRCOj\
EwOjlBOkVFOjc3OkY1OjQ0OjU3Ojg0OjUxOkRCOjA0OjkyOjRBOjEwOjFDOjRFOjVGOjdFOkYxOkIzOjcxOjIyXHJcbmE9Z3JvdXA6QlVORExFIDBcclxu\
YT1leHRtYXAtYWxsb3ctbWl4ZWRcclxubT1hcHBsaWNhdGlvbiA5IFVEUC9EVExTL1NDVFAgd2VicnRjLWRhdGFjaGFubmVsXHJcbmM9SU4gSVA0IDAuMC\
4wLjBcclxuYT1zZXR1cDphY3RpdmVcclxuYT1taWQ6MFxyXG5hPXNlbmRyZWN2XHJcbmE9c2N0cC1wb3J0OjUwMDBcclxuYT1pY2UtdWZyYWc6SHlnU1Rr\
V2RsRlpHRG1XWlxyXG5hPWljZS1wd2Q6SkJneFZWaGZveVhHdHZha1VWcnBQeHVOSVpMU3llS1pcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDEgdWRwID\
IxMzA3MDY0MzEgMTkyLjE2OC4xLjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDIgdWRwIDIxMzA3MDY0MzEgMTkyLjE2OC4x\
LjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6MTg2MTA0NTE5MCAxIHVkcCAxNjk0NDk4ODE1IDE0LjIxMi42OC4xMiAyNzAwNCB0eXAgc3JmbH\
ggcmFkZHIgMC4wLjAuMCBycG9ydCA2NDAwOFxyXG5hPWNhbmRpZGF0ZToxODYxMDQ1MTkwIDIgdWRwIDE2OTQ0OTg4MTUgMTQuMjEyLjY4LjEyIDI3MDA0\
IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNcclxuIn0=".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"connect to an 'answer' webrtc endpoint should error"
);
}
}
+531
View File
@@ -0,0 +1,531 @@
use crate::{
config::{
keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT,
},
protobuf::Message,
socket_client::split_host_port,
sodiumoxide::crypto::secretbox::Key,
tcp::Encrypt,
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
use anyhow::bail;
use async_recursion::async_recursion;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use std::{
io::{Error, ErrorKind},
net::SocketAddr,
sync::Arc,
time::Duration,
};
use tokio::{net::TcpStream, time::timeout};
use tokio_native_tls::native_tls::TlsConnector;
use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector,
MaybeTlsStream, WebSocketStream,
};
use tungstenite::client::IntoClientRequest;
use tungstenite::protocol::Role;
pub struct WsFramedStream {
stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
addr: SocketAddr,
encrypt: Option<Encrypt>,
send_timeout: u64,
}
impl WsFramedStream {
#[inline]
fn get_connector(
tls_type: &TlsType,
danger_accept_invalid_certs: bool,
) -> ResultType<Option<Connector>> {
match tls_type {
TlsType::Plain => Ok(Some(Connector::Plain)),
TlsType::NativeTls => {
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(danger_accept_invalid_certs)
.build()?;
Ok(Some(Connector::NativeTls(connector)))
}
TlsType::Rustls => {
let connector = match crate::verifier::client_config(danger_accept_invalid_certs) {
Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))),
Err(e) => {
log::warn!(
"Failed to get client config: {:?}, fallback to default connector",
e
);
None
}
};
Ok(connector)
}
}
}
async fn connect(
url: &str,
ms_timeout: u64,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
// to-do: websocket proxy.
let tls_type = get_cached_tls_type(url);
let is_tls_type_cached = tls_type.is_some();
let tls_type = tls_type.unwrap_or(TlsType::Rustls);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await
}
#[async_recursion]
async fn try_connect(
url: &str,
ms_timeout: u64,
tls_type: TlsType,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
original_danger_accept_invalid_certs: Option<bool>,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let ws_config = None;
let disable_nagle = false;
let request = url
.into_client_request()
.map_err(|e| Error::new(ErrorKind::Other, e))?;
let connector =
Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?;
match timeout(
Duration::from_millis(ms_timeout),
connect_async_tls_with_config(request, ws_config, disable_nagle, connector),
)
.await?
{
Ok((ws_stream, _)) => {
upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false));
Ok(ws_stream)
}
Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) {
(TlsType::Rustls, _, None) => {
log::warn!(
"WebSocket connection with rustls-tls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::Rustls, false, Some(_)) => {
log::warn!(
"WebSocket connection with rustls-tls failed, try native-tls: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
TlsType::NativeTls,
is_tls_type_cached,
original_danger_accept_invalid_certs,
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::NativeTls, _, None) => {
log::warn!(
"WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
_ => {
log::error!(
"WebSocket connection failed with tls_type {:?}: {}, {:?}",
tls_type,
url,
e
);
bail!(e)
}
},
}
}
pub async fn new<T: AsRef<str>>(
url: T,
_local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>,
ms_timeout: u64,
) -> ResultType<Self> {
let stream = Self::connect(url.as_ref(), ms_timeout).await?;
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?,
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
let ws = Self {
stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
};
Ok(ws)
}
#[inline]
pub fn set_raw(&mut self) {
self.encrypt = None;
}
#[inline]
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
let ws_stream =
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
.await;
Ok(Self {
stream: ws_stream,
addr,
encrypt: None,
send_timeout: 0,
})
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
self.addr
}
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
self.send_timeout = ms;
}
#[inline]
pub fn set_key(&mut self, key: Key) {
self.encrypt = Some(Encrypt::new(key));
}
#[inline]
pub fn is_secured(&self) -> bool {
self.encrypt.is_some()
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
let mut msg = msg;
if let Some(key) = self.encrypt.as_mut() {
msg = key.enc(&msg);
}
self.send_bytes(Bytes::from(msg)).await
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
let msg = WsMessage::Binary(bytes);
if self.send_timeout > 0 {
timeout(
Duration::from_millis(self.send_timeout),
self.stream.send(msg),
)
.await??
} else {
self.stream.send(msg).await?
};
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
while let Some(msg) = self.stream.next().await {
let msg = match msg {
Ok(msg) => msg,
Err(e) => {
log::error!("{}", e);
return Some(Err(Error::new(
ErrorKind::Other,
format!("WebSocket protocol error: {}", e),
)));
}
};
match msg {
WsMessage::Binary(data) => {
let mut bytes = BytesMut::from(&data[..]);
if let Some(key) = self.encrypt.as_mut() {
if let Err(err) = key.dec(&mut bytes) {
return Some(Err(err));
}
}
return Some(Ok(bytes));
}
WsMessage::Text(text) => {
let bytes = BytesMut::from(text.as_bytes());
return Some(Ok(bytes));
}
WsMessage::Close(_) => {
return None;
}
_ => {
continue;
}
}
}
None
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await {
Ok(res) => res,
Err(_) => None,
}
}
}
pub fn is_ws_endpoint(endpoint: &str) -> bool {
endpoint.starts_with("ws://") || endpoint.starts_with("wss://")
}
/**
* Core function to convert an endpoint to WebSocket format
*
* Converts between different address formats:
* 1. IPv4 address with/without port -> ws://ipv4:port
* 2. IPv6 address with/without port -> ws://[ipv6]:port
* 3. Domain with/without port -> ws(s)://domain/ws/path
*
* @param endpoint The endpoint to convert
* @return The converted WebSocket endpoint
*/
pub fn check_ws(endpoint: &str) -> String {
if !use_ws() {
return endpoint.to_string();
}
if endpoint.is_empty() {
return endpoint.to_string();
}
if is_ws_endpoint(endpoint) {
return endpoint.to_string();
}
let Some((endpoint_host, endpoint_port)) = split_host_port(endpoint) else {
debug_assert!(false, "endpoint doesn't have port");
return endpoint.to_string();
};
let custom_rendezvous_server = Config::get_rendezvous_server();
let relay_server = Config::get_option(OPTION_RELAY_SERVER);
let rendezvous_port = split_host_port(&custom_rendezvous_server)
.map(|(_, p)| p)
.unwrap_or(RENDEZVOUS_PORT);
let relay_port = split_host_port(&relay_server)
.map(|(_, p)| p)
.unwrap_or(RELAY_PORT);
let (relay, dst_port) = if endpoint_port == rendezvous_port {
// rendezvous
(false, endpoint_port + 2)
} else if endpoint_port == rendezvous_port - 1 {
// online
(false, endpoint_port + 3)
} else if endpoint_port == relay_port || endpoint_port == rendezvous_port + 1 {
// relay
// https://github.com/rustdesk/rustdesk/blob/6ffbcd1375771f2482ec4810680623a269be70f1/src/rendezvous_mediator.rs#L615
// https://github.com/rustdesk/rustdesk-server/blob/235a3c326ceb665e941edb50ab79faa1208f7507/src/relay_server.rs#L83, based on relay port.
(true, endpoint_port + 2)
} else {
// fallback relay
// for controlling side, relay server is passed from the controlled side, not related to local config.
(true, endpoint_port + 2)
};
let (address, is_domain) = if crate::is_ip_str(endpoint) {
(format!("{}:{}", endpoint_host, dst_port), false)
} else {
let domain_path = if relay { "/ws/relay" } else { "/ws/id" };
(format!("{}{}", endpoint_host, domain_path), true)
};
let protocol = if is_domain {
let api_server = Config::get_option("api-server");
if api_server.starts_with("https") {
"wss"
} else {
"ws"
}
} else {
"ws"
};
format!("{}://{}", protocol, address)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{keys, Config};
#[test]
fn test_check_ws() {
// enable websocket
Config::set_option(keys::OPTION_ALLOW_WEBSOCKET.to_string(), "Y".to_string());
// not set custom-rendezvous-server
Config::set_option("custom-rendezvous-server".to_string(), "".to_string());
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
assert_eq!(check_ws("rustdesk.com:21115"), "ws://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "ws://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21117"), "ws://rustdesk.com/ws/relay");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
Config::set_option(
"api-server".to_string(),
"https://api.rustdesk.com".to_string(),
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21115"),
"ws://[0:0:0:0:0:0:0:1]:21118"
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21116"),
"ws://[0:0:0:0:0:0:0:1]:21118"
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21117"),
"ws://[0:0:0:0:0:0:0:1]:21119"
);
assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id");
assert_eq!(
check_ws("rustdesk.com:21117"),
"wss://rustdesk.com/ws/relay"
);
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id");
assert_eq!(
check_ws("rustdesk.com:34567"),
"wss://rustdesk.com/ws/relay"
);
// set custom-rendezvous-server without port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
// set custom-rendezvous-server without default port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
// set custom-rendezvous-server with custom port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1:23456".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23457"), "ws://127.0.0.1:23459");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
}
}
+9
View File
@@ -0,0 +1,9 @@
[package]
name = "libxdo-sys"
version = "0.11.0"
edition = "2021"
publish = false
description = "Dynamic loading wrapper for libxdo-sys that doesn't require libxdo at compile/link time"
[dependencies]
hbb_common = { path = "../hbb_common" }
+505
View File
@@ -0,0 +1,505 @@
//! Dynamic loading wrapper for libxdo.
//!
//! Provides the same API as libxdo-sys but loads libxdo at runtime,
//! allowing the program to run on systems without libxdo installed
//! (e.g., Wayland-only environments).
use hbb_common::{
libc::{c_char, c_int, c_uint},
libloading::{Library, Symbol},
log,
};
use std::sync::OnceLock;
pub use hbb_common::x11::xlib::{Display, Screen, Window};
#[repr(C)]
pub struct xdo_t {
_private: [u8; 0],
}
#[repr(C)]
pub struct charcodemap_t {
_private: [u8; 0],
}
#[repr(C)]
pub struct xdo_search_t {
_private: [u8; 0],
}
pub type useconds_t = c_uint;
pub const CURRENTWINDOW: Window = 0;
type FnXdoNew = unsafe extern "C" fn(*const c_char) -> *mut xdo_t;
type FnXdoNewWithOpenedDisplay =
unsafe extern "C" fn(*mut Display, *const c_char, c_int) -> *mut xdo_t;
type FnXdoFree = unsafe extern "C" fn(*mut xdo_t);
type FnXdoSendKeysequenceWindow =
unsafe extern "C" fn(*const xdo_t, Window, *const c_char, useconds_t) -> c_int;
type FnXdoSendKeysequenceWindowDown =
unsafe extern "C" fn(*const xdo_t, Window, *const c_char, useconds_t) -> c_int;
type FnXdoSendKeysequenceWindowUp =
unsafe extern "C" fn(*const xdo_t, Window, *const c_char, useconds_t) -> c_int;
type FnXdoEnterTextWindow =
unsafe extern "C" fn(*const xdo_t, Window, *const c_char, useconds_t) -> c_int;
type FnXdoClickWindow = unsafe extern "C" fn(*const xdo_t, Window, c_int) -> c_int;
type FnXdoMouseDown = unsafe extern "C" fn(*const xdo_t, Window, c_int) -> c_int;
type FnXdoMouseUp = unsafe extern "C" fn(*const xdo_t, Window, c_int) -> c_int;
type FnXdoMoveMouse = unsafe extern "C" fn(*const xdo_t, c_int, c_int, c_int) -> c_int;
type FnXdoMoveMouseRelative = unsafe extern "C" fn(*const xdo_t, c_int, c_int) -> c_int;
type FnXdoMoveMouseRelativeToWindow =
unsafe extern "C" fn(*const xdo_t, Window, c_int, c_int) -> c_int;
type FnXdoGetMouseLocation =
unsafe extern "C" fn(*const xdo_t, *mut c_int, *mut c_int, *mut c_int) -> c_int;
type FnXdoGetMouseLocation2 =
unsafe extern "C" fn(*const xdo_t, *mut c_int, *mut c_int, *mut c_int, *mut Window) -> c_int;
type FnXdoGetActiveWindow = unsafe extern "C" fn(*const xdo_t, *mut Window) -> c_int;
type FnXdoGetFocusedWindow = unsafe extern "C" fn(*const xdo_t, *mut Window) -> c_int;
type FnXdoGetFocusedWindowSane = unsafe extern "C" fn(*const xdo_t, *mut Window) -> c_int;
type FnXdoGetWindowLocation =
unsafe extern "C" fn(*const xdo_t, Window, *mut c_int, *mut c_int, *mut *mut Screen) -> c_int;
type FnXdoGetWindowSize =
unsafe extern "C" fn(*const xdo_t, Window, *mut c_uint, *mut c_uint) -> c_int;
type FnXdoGetInputState = unsafe extern "C" fn(*const xdo_t) -> c_uint;
type FnXdoActivateWindow = unsafe extern "C" fn(*const xdo_t, Window) -> c_int;
type FnXdoWaitForMouseMoveFrom = unsafe extern "C" fn(*const xdo_t, c_int, c_int) -> c_int;
type FnXdoWaitForMouseMoveTo = unsafe extern "C" fn(*const xdo_t, c_int, c_int) -> c_int;
type FnXdoSetWindowClass =
unsafe extern "C" fn(*const xdo_t, Window, *const c_char, *const c_char) -> c_int;
type FnXdoSearchWindows =
unsafe extern "C" fn(*const xdo_t, *const xdo_search_t, *mut *mut Window, *mut c_uint) -> c_int;
struct XdoLib {
_lib: Library,
xdo_new: FnXdoNew,
xdo_new_with_opened_display: Option<FnXdoNewWithOpenedDisplay>,
xdo_free: FnXdoFree,
xdo_send_keysequence_window: FnXdoSendKeysequenceWindow,
xdo_send_keysequence_window_down: Option<FnXdoSendKeysequenceWindowDown>,
xdo_send_keysequence_window_up: Option<FnXdoSendKeysequenceWindowUp>,
xdo_enter_text_window: Option<FnXdoEnterTextWindow>,
xdo_click_window: Option<FnXdoClickWindow>,
xdo_mouse_down: Option<FnXdoMouseDown>,
xdo_mouse_up: Option<FnXdoMouseUp>,
xdo_move_mouse: Option<FnXdoMoveMouse>,
xdo_move_mouse_relative: Option<FnXdoMoveMouseRelative>,
xdo_move_mouse_relative_to_window: Option<FnXdoMoveMouseRelativeToWindow>,
xdo_get_mouse_location: Option<FnXdoGetMouseLocation>,
xdo_get_mouse_location2: Option<FnXdoGetMouseLocation2>,
xdo_get_active_window: Option<FnXdoGetActiveWindow>,
xdo_get_focused_window: Option<FnXdoGetFocusedWindow>,
xdo_get_focused_window_sane: Option<FnXdoGetFocusedWindowSane>,
xdo_get_window_location: Option<FnXdoGetWindowLocation>,
xdo_get_window_size: Option<FnXdoGetWindowSize>,
xdo_get_input_state: Option<FnXdoGetInputState>,
xdo_activate_window: Option<FnXdoActivateWindow>,
xdo_wait_for_mouse_move_from: Option<FnXdoWaitForMouseMoveFrom>,
xdo_wait_for_mouse_move_to: Option<FnXdoWaitForMouseMoveTo>,
xdo_set_window_class: Option<FnXdoSetWindowClass>,
xdo_search_windows: Option<FnXdoSearchWindows>,
}
impl XdoLib {
fn load() -> Option<Self> {
// https://github.com/rustdesk/rustdesk/issues/13711
const LIB_NAMES: [&str; 3] = ["libxdo.so.4", "libxdo.so.3", "libxdo.so"];
unsafe {
let (lib, lib_name) = LIB_NAMES
.iter()
.find_map(|name| Library::new(name).ok().map(|lib| (lib, *name)))?;
log::info!("libxdo-sys Loaded {}", lib_name);
let xdo_new: FnXdoNew = *lib.get(b"xdo_new").ok()?;
let xdo_free: FnXdoFree = *lib.get(b"xdo_free").ok()?;
let xdo_send_keysequence_window: FnXdoSendKeysequenceWindow =
*lib.get(b"xdo_send_keysequence_window").ok()?;
let xdo_new_with_opened_display = lib
.get(b"xdo_new_with_opened_display")
.ok()
.map(|s: Symbol<FnXdoNewWithOpenedDisplay>| *s);
let xdo_send_keysequence_window_down = lib
.get(b"xdo_send_keysequence_window_down")
.ok()
.map(|s: Symbol<FnXdoSendKeysequenceWindowDown>| *s);
let xdo_send_keysequence_window_up = lib
.get(b"xdo_send_keysequence_window_up")
.ok()
.map(|s: Symbol<FnXdoSendKeysequenceWindowUp>| *s);
let xdo_enter_text_window = lib
.get(b"xdo_enter_text_window")
.ok()
.map(|s: Symbol<FnXdoEnterTextWindow>| *s);
let xdo_click_window = lib
.get(b"xdo_click_window")
.ok()
.map(|s: Symbol<FnXdoClickWindow>| *s);
let xdo_mouse_down = lib
.get(b"xdo_mouse_down")
.ok()
.map(|s: Symbol<FnXdoMouseDown>| *s);
let xdo_mouse_up = lib
.get(b"xdo_mouse_up")
.ok()
.map(|s: Symbol<FnXdoMouseUp>| *s);
let xdo_move_mouse = lib
.get(b"xdo_move_mouse")
.ok()
.map(|s: Symbol<FnXdoMoveMouse>| *s);
let xdo_move_mouse_relative = lib
.get(b"xdo_move_mouse_relative")
.ok()
.map(|s: Symbol<FnXdoMoveMouseRelative>| *s);
let xdo_move_mouse_relative_to_window = lib
.get(b"xdo_move_mouse_relative_to_window")
.ok()
.map(|s: Symbol<FnXdoMoveMouseRelativeToWindow>| *s);
let xdo_get_mouse_location = lib
.get(b"xdo_get_mouse_location")
.ok()
.map(|s: Symbol<FnXdoGetMouseLocation>| *s);
let xdo_get_mouse_location2 = lib
.get(b"xdo_get_mouse_location2")
.ok()
.map(|s: Symbol<FnXdoGetMouseLocation2>| *s);
let xdo_get_active_window = lib
.get(b"xdo_get_active_window")
.ok()
.map(|s: Symbol<FnXdoGetActiveWindow>| *s);
let xdo_get_focused_window = lib
.get(b"xdo_get_focused_window")
.ok()
.map(|s: Symbol<FnXdoGetFocusedWindow>| *s);
let xdo_get_focused_window_sane = lib
.get(b"xdo_get_focused_window_sane")
.ok()
.map(|s: Symbol<FnXdoGetFocusedWindowSane>| *s);
let xdo_get_window_location = lib
.get(b"xdo_get_window_location")
.ok()
.map(|s: Symbol<FnXdoGetWindowLocation>| *s);
let xdo_get_window_size = lib
.get(b"xdo_get_window_size")
.ok()
.map(|s: Symbol<FnXdoGetWindowSize>| *s);
let xdo_get_input_state = lib
.get(b"xdo_get_input_state")
.ok()
.map(|s: Symbol<FnXdoGetInputState>| *s);
let xdo_activate_window = lib
.get(b"xdo_activate_window")
.ok()
.map(|s: Symbol<FnXdoActivateWindow>| *s);
let xdo_wait_for_mouse_move_from = lib
.get(b"xdo_wait_for_mouse_move_from")
.ok()
.map(|s: Symbol<FnXdoWaitForMouseMoveFrom>| *s);
let xdo_wait_for_mouse_move_to = lib
.get(b"xdo_wait_for_mouse_move_to")
.ok()
.map(|s: Symbol<FnXdoWaitForMouseMoveTo>| *s);
let xdo_set_window_class = lib
.get(b"xdo_set_window_class")
.ok()
.map(|s: Symbol<FnXdoSetWindowClass>| *s);
let xdo_search_windows = lib
.get(b"xdo_search_windows")
.ok()
.map(|s: Symbol<FnXdoSearchWindows>| *s);
Some(Self {
_lib: lib,
xdo_new,
xdo_new_with_opened_display,
xdo_free,
xdo_send_keysequence_window,
xdo_send_keysequence_window_down,
xdo_send_keysequence_window_up,
xdo_enter_text_window,
xdo_click_window,
xdo_mouse_down,
xdo_mouse_up,
xdo_move_mouse,
xdo_move_mouse_relative,
xdo_move_mouse_relative_to_window,
xdo_get_mouse_location,
xdo_get_mouse_location2,
xdo_get_active_window,
xdo_get_focused_window,
xdo_get_focused_window_sane,
xdo_get_window_location,
xdo_get_window_size,
xdo_get_input_state,
xdo_activate_window,
xdo_wait_for_mouse_move_from,
xdo_wait_for_mouse_move_to,
xdo_set_window_class,
xdo_search_windows,
})
}
}
}
static XDO_LIB: OnceLock<Option<XdoLib>> = OnceLock::new();
fn get_lib() -> Option<&'static XdoLib> {
XDO_LIB
.get_or_init(|| {
let lib = XdoLib::load();
if lib.is_none() {
log::info!("libxdo-sys libxdo not found, xdo functions will be disabled");
}
lib
})
.as_ref()
}
pub unsafe extern "C" fn xdo_new(display: *const c_char) -> *mut xdo_t {
get_lib().map_or(std::ptr::null_mut(), |lib| (lib.xdo_new)(display))
}
pub unsafe extern "C" fn xdo_new_with_opened_display(
xdpy: *mut Display,
display: *const c_char,
close_display_when_freed: c_int,
) -> *mut xdo_t {
get_lib()
.and_then(|lib| lib.xdo_new_with_opened_display)
.map_or(std::ptr::null_mut(), |f| {
f(xdpy, display, close_display_when_freed)
})
}
pub unsafe extern "C" fn xdo_free(xdo: *mut xdo_t) {
if xdo.is_null() {
return;
}
if let Some(lib) = get_lib() {
(lib.xdo_free)(xdo);
}
}
pub unsafe extern "C" fn xdo_send_keysequence_window(
xdo: *const xdo_t,
window: Window,
keysequence: *const c_char,
delay: useconds_t,
) -> c_int {
get_lib().map_or(1, |lib| {
(lib.xdo_send_keysequence_window)(xdo, window, keysequence, delay)
})
}
pub unsafe extern "C" fn xdo_send_keysequence_window_down(
xdo: *const xdo_t,
window: Window,
keysequence: *const c_char,
delay: useconds_t,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_send_keysequence_window_down)
.map_or(1, |f| f(xdo, window, keysequence, delay))
}
pub unsafe extern "C" fn xdo_send_keysequence_window_up(
xdo: *const xdo_t,
window: Window,
keysequence: *const c_char,
delay: useconds_t,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_send_keysequence_window_up)
.map_or(1, |f| f(xdo, window, keysequence, delay))
}
pub unsafe extern "C" fn xdo_enter_text_window(
xdo: *const xdo_t,
window: Window,
string: *const c_char,
delay: useconds_t,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_enter_text_window)
.map_or(1, |f| f(xdo, window, string, delay))
}
pub unsafe extern "C" fn xdo_click_window(
xdo: *const xdo_t,
window: Window,
button: c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_click_window)
.map_or(1, |f| f(xdo, window, button))
}
pub unsafe extern "C" fn xdo_mouse_down(xdo: *const xdo_t, window: Window, button: c_int) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_mouse_down)
.map_or(1, |f| f(xdo, window, button))
}
pub unsafe extern "C" fn xdo_mouse_up(xdo: *const xdo_t, window: Window, button: c_int) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_mouse_up)
.map_or(1, |f| f(xdo, window, button))
}
pub unsafe extern "C" fn xdo_move_mouse(
xdo: *const xdo_t,
x: c_int,
y: c_int,
screen: c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_move_mouse)
.map_or(1, |f| f(xdo, x, y, screen))
}
pub unsafe extern "C" fn xdo_move_mouse_relative(xdo: *const xdo_t, x: c_int, y: c_int) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_move_mouse_relative)
.map_or(1, |f| f(xdo, x, y))
}
pub unsafe extern "C" fn xdo_move_mouse_relative_to_window(
xdo: *const xdo_t,
window: Window,
x: c_int,
y: c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_move_mouse_relative_to_window)
.map_or(1, |f| f(xdo, window, x, y))
}
pub unsafe extern "C" fn xdo_get_mouse_location(
xdo: *const xdo_t,
x: *mut c_int,
y: *mut c_int,
screen_num: *mut c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_mouse_location)
.map_or(1, |f| f(xdo, x, y, screen_num))
}
pub unsafe extern "C" fn xdo_get_mouse_location2(
xdo: *const xdo_t,
x: *mut c_int,
y: *mut c_int,
screen_num: *mut c_int,
window: *mut Window,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_mouse_location2)
.map_or(1, |f| f(xdo, x, y, screen_num, window))
}
pub unsafe extern "C" fn xdo_get_active_window(
xdo: *const xdo_t,
window_ret: *mut Window,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_active_window)
.map_or(1, |f| f(xdo, window_ret))
}
pub unsafe extern "C" fn xdo_get_focused_window(
xdo: *const xdo_t,
window_ret: *mut Window,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_focused_window)
.map_or(1, |f| f(xdo, window_ret))
}
pub unsafe extern "C" fn xdo_get_focused_window_sane(
xdo: *const xdo_t,
window_ret: *mut Window,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_focused_window_sane)
.map_or(1, |f| f(xdo, window_ret))
}
pub unsafe extern "C" fn xdo_get_window_location(
xdo: *const xdo_t,
window: Window,
x: *mut c_int,
y: *mut c_int,
screen_ret: *mut *mut Screen,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_window_location)
.map_or(1, |f| f(xdo, window, x, y, screen_ret))
}
pub unsafe extern "C" fn xdo_get_window_size(
xdo: *const xdo_t,
window: Window,
width: *mut c_uint,
height: *mut c_uint,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_get_window_size)
.map_or(1, |f| f(xdo, window, width, height))
}
pub unsafe extern "C" fn xdo_get_input_state(xdo: *const xdo_t) -> c_uint {
get_lib()
.and_then(|lib| lib.xdo_get_input_state)
.map_or(0, |f| f(xdo))
}
pub unsafe extern "C" fn xdo_activate_window(xdo: *const xdo_t, wid: Window) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_activate_window)
.map_or(1, |f| f(xdo, wid))
}
pub unsafe extern "C" fn xdo_wait_for_mouse_move_from(
xdo: *const xdo_t,
origin_x: c_int,
origin_y: c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_wait_for_mouse_move_from)
.map_or(1, |f| f(xdo, origin_x, origin_y))
}
pub unsafe extern "C" fn xdo_wait_for_mouse_move_to(
xdo: *const xdo_t,
dest_x: c_int,
dest_y: c_int,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_wait_for_mouse_move_to)
.map_or(1, |f| f(xdo, dest_x, dest_y))
}
pub unsafe extern "C" fn xdo_set_window_class(
xdo: *const xdo_t,
wid: Window,
name: *const c_char,
class: *const c_char,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_set_window_class)
.map_or(1, |f| f(xdo, wid, name, class))
}
pub unsafe extern "C" fn xdo_search_windows(
xdo: *const xdo_t,
search: *const xdo_search_t,
windowlist_ret: *mut *mut Window,
nwindows_ret: *mut c_uint,
) -> c_int {
get_lib()
.and_then(|lib| lib.xdo_search_windows)
.map_or(1, |f| f(xdo, search, windowlist_ret, nwindows_ret))
}
+3
View File
@@ -0,0 +1,3 @@
/target
*.exe
*.bin
+39
View File
@@ -0,0 +1,39 @@
[package]
name = "rustdesk-portable-packer"
version = "1.4.6"
edition = "2021"
description = "RustDesk Remote Desktop"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
build = "build.rs"
[dependencies]
brotli = "3.4"
dirs = "5.0"
md5 = "0.7"
winapi = { version = "0.3", features = ["winbase"] }
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.61", features = [
"Wdk",
"Wdk_System",
"Wdk_System_SystemServices",
"Win32",
"Win32_System",
"Win32_System_SystemInformation",
] }
native-windows-gui = {version = "1.0", default-features = false, features = ["animation-timer", "image-decoder"]}
[package.metadata.winres]
LegalCopyright = "Copyright © 2025 cStudio GmbH. All rights reserved."
ProductName = "RustDesk"
OriginalFilename = "rustdesk.exe"
FileDescription = "RustDesk Remote Desktop"
#ProductVersion = ""
[target.'cfg(target_os="windows")'.build-dependencies]
winres = "0.1"
winapi = { version = "0.3", features = [ "winnt", "pdh", "synchapi" ] }
+20
View File
@@ -0,0 +1,20 @@
fn main() {
#[cfg(windows)]
{
use std::io::Write;
let mut res = winres::WindowsResource::new();
res.set_icon("../../res/icon.ico")
.set_language(winapi::um::winnt::MAKELANGID(
winapi::um::winnt::LANG_ENGLISH,
winapi::um::winnt::SUBLANG_ENGLISH_US,
))
.set_manifest_file("../../res/manifest.xml");
match res.compile() {
Err(e) => {
write!(std::io::stderr(), "{}", e).unwrap();
std::process::exit(1);
}
Ok(_) => {}
}
}
}
+108
View File
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
import os
import optparse
from hashlib import md5
import brotli
import datetime
# 4GB maximum
length_count = 4
# encoding
encoding = 'utf-8'
# output: {path: (compressed_data, file_md5)}
def generate_md5_table(folder: str, level) -> dict:
res: dict = dict()
curdir = os.curdir
os.chdir(folder)
for root, _, files in os.walk('.'):
# remove ./
for f in files:
md5_generator = md5()
full_path = os.path.join(root, f)
print(f"Processing {full_path}...")
f = open(full_path, "rb")
content = f.read()
content_compressed = brotli.compress(
content, quality=level)
md5_generator.update(content)
md5_code = md5_generator.hexdigest().encode(encoding=encoding)
res[full_path] = (content_compressed, md5_code)
os.chdir(curdir)
return res
def write_package_metadata(md5_table: dict, output_folder: str, exe: str):
output_path = os.path.join(output_folder, "data.bin")
with open(output_path, "wb") as f:
f.write("rustdesk".encode(encoding=encoding))
for path in md5_table.keys():
(compressed_data, md5_code) = md5_table[path]
data_length = len(compressed_data)
path = path.encode(encoding=encoding)
# path length & path
f.write((len(path)).to_bytes(length=length_count, byteorder='big'))
f.write(path)
# data length & compressed data
f.write(data_length.to_bytes(
length=length_count, byteorder='big'))
f.write(compressed_data)
# md5 code
f.write(md5_code)
# end
f.write("rustdesk".encode(encoding=encoding))
# executable
f.write(exe.encode(encoding='utf-8'))
print(f"Metadata has been written to {output_path}")
def write_app_metadata(output_folder: str):
output_path = os.path.join(output_folder, "app_metadata.toml")
with open(output_path, "w") as f:
f.write(f"timestamp = {int(datetime.datetime.now().timestamp() * 1000)}\n")
print(f"App metadata has been written to {output_path}")
def build_portable(output_folder: str, target: str):
os.chdir(output_folder)
if target:
os.system("cargo build --release --target " + target)
else:
os.system("cargo build --release")
# Linux: python3 generate.py -f ../rustdesk-portable-packer/test -o . -e ./test/main.py
# Windows: python3 .\generate.py -f ..\rustdesk\flutter\build\windows\runner\Debug\ -o . -e ..\rustdesk\flutter\build\windows\runner\Debug\rustdesk.exe
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option("-f", "--folder", dest="folder",
help="folder to compress")
parser.add_option("-o", "--output", dest="output_folder",
help="the root of portable packer project, default is './'")
parser.add_option("-e", "--executable", dest="executable",
help="specify startup file in --folder, default is rustdesk.exe")
parser.add_option("-t", "--target", dest="target",
help="the target used by cargo")
parser.add_option("-l", "--level", dest="level", type="int",
help="compression level, default is 11, highest", default=11)
(options, args) = parser.parse_args()
folder = options.folder or './rustdesk'
output_folder = os.path.abspath(options.output_folder or './')
if not options.executable:
options.executable = 'rustdesk.exe'
if not options.executable.startswith(folder):
options.executable = folder + '/' + options.executable
exe: str = os.path.abspath(options.executable)
if not exe.startswith(os.path.abspath(folder)):
print("The executable must locate in source folder")
exit(-1)
exe = '.' + exe[len(os.path.abspath(folder)):]
print("Executable path: " + exe)
print("Compression level: " + str(options.level))
md5_table = generate_md5_table(folder, options.level)
write_package_metadata(md5_table, output_folder, exe)
write_app_metadata(output_folder)
build_portable(output_folder, options.target)
+1
View File
@@ -0,0 +1 @@
brotli
+139
View File
@@ -0,0 +1,139 @@
use std::{
fs::{self},
io::{Cursor, Read},
path::Path,
};
#[cfg(windows)]
const BIN_DATA: &[u8] = include_bytes!("../data.bin");
#[cfg(not(windows))]
const BIN_DATA: &[u8] = &[];
// 4bytes
const LENGTH: usize = 4;
const IDENTIFIER_LENGTH: usize = 8;
const MD5_LENGTH: usize = 32;
const BUF_SIZE: usize = 4096;
pub(crate) struct BinaryData {
pub md5_code: &'static [u8],
// compressed gzip data
pub raw: &'static [u8],
pub path: String,
}
pub(crate) struct BinaryReader {
pub files: Vec<BinaryData>,
pub exe: String,
}
impl Default for BinaryReader {
fn default() -> Self {
let (files, exe) = BinaryReader::read();
Self { files, exe }
}
}
impl BinaryData {
fn decompress(&self) -> Vec<u8> {
let cursor = Cursor::new(self.raw);
let mut decoder = brotli::Decompressor::new(cursor, BUF_SIZE);
let mut buf = Vec::new();
decoder.read_to_end(&mut buf).ok();
buf
}
pub fn write_to_file(&self, prefix: &Path) {
let p = prefix.join(&self.path);
if let Some(parent) = p.parent() {
if !parent.exists() {
let _ = fs::create_dir_all(parent);
}
}
if p.exists() {
// check md5
let f = fs::read(p.clone()).unwrap_or_default();
let digest = format!("{:x}", md5::compute(&f));
let md5_record = String::from_utf8_lossy(self.md5_code);
if digest == md5_record {
// same, skip this file
println!("skip {}", &self.path);
return;
} else {
println!("writing {}", p.display());
println!("{} -> {}", md5_record, digest)
}
}
let _ = fs::write(p, self.decompress());
}
}
impl BinaryReader {
fn read() -> (Vec<BinaryData>, String) {
let mut base: usize = 0;
let mut parsed = vec![];
assert!(BIN_DATA.len() > IDENTIFIER_LENGTH, "bin data invalid!");
let mut iden = String::from_utf8_lossy(&BIN_DATA[base..base + IDENTIFIER_LENGTH]);
if iden != "rustdesk" {
panic!("bin file is not valid!");
}
base += IDENTIFIER_LENGTH;
loop {
iden = String::from_utf8_lossy(&BIN_DATA[base..base + IDENTIFIER_LENGTH]);
if iden == "rustdesk" {
base += IDENTIFIER_LENGTH;
break;
}
// start reading
let mut offset = 0;
let path_length = u32::from_be_bytes([
BIN_DATA[base + offset],
BIN_DATA[base + offset + 1],
BIN_DATA[base + offset + 2],
BIN_DATA[base + offset + 3],
]) as usize;
offset += LENGTH;
let path =
String::from_utf8_lossy(&BIN_DATA[base + offset..base + offset + path_length])
.to_string();
offset += path_length;
// file sz
let file_length = u32::from_be_bytes([
BIN_DATA[base + offset],
BIN_DATA[base + offset + 1],
BIN_DATA[base + offset + 2],
BIN_DATA[base + offset + 3],
]) as usize;
offset += LENGTH;
let raw = &BIN_DATA[base + offset..base + offset + file_length];
offset += file_length;
// md5
let md5 = &BIN_DATA[base + offset..base + offset + MD5_LENGTH];
offset += MD5_LENGTH;
parsed.push(BinaryData {
md5_code: md5,
raw: raw,
path: path,
});
base += offset;
}
// executable
let executable = String::from_utf8_lossy(&BIN_DATA[base..]).to_string();
(parsed, executable)
}
#[cfg(linux)]
pub fn configure_permission(&self, prefix: &Path) {
use std::os::unix::prelude::PermissionsExt;
let exe_path = prefix.join(&self.exe);
if exe_path.exists() {
if let Ok(f) = File::open(exe_path) {
if let Ok(meta) = f.metadata() {
let mut permissions = meta.permissions();
permissions.set_mode(0o755);
f.set_permissions(permissions).ok();
}
}
}
}
}
+248
View File
@@ -0,0 +1,248 @@
#![windows_subsystem = "windows"]
use std::{
path::{Path, PathBuf},
process::{Command, Stdio},
};
use bin_reader::BinaryReader;
pub mod bin_reader;
#[cfg(windows)]
mod ui;
#[cfg(windows)]
const APP_METADATA: &[u8] = include_bytes!("../app_metadata.toml");
#[cfg(not(windows))]
const APP_METADATA: &[u8] = &[];
const APP_METADATA_CONFIG: &str = "meta.toml";
const META_LINE_PREFIX_TIMESTAMP: &str = "timestamp = ";
const APP_PREFIX: &str = "rustdesk";
const APPNAME_RUNTIME_ENV_KEY: &str = "RUSTDESK_APPNAME";
#[cfg(windows)]
const SET_FOREGROUND_WINDOW_ENV_KEY: &str = "SET_FOREGROUND_WINDOW";
fn is_timestamp_matches(dir: &Path, ts: &mut u64) -> bool {
let Ok(app_metadata) = std::str::from_utf8(APP_METADATA) else {
return true;
};
for line in app_metadata.lines() {
if line.starts_with(META_LINE_PREFIX_TIMESTAMP) {
if let Ok(stored_ts) = line.replace(META_LINE_PREFIX_TIMESTAMP, "").parse::<u64>() {
*ts = stored_ts;
break;
}
}
}
if *ts == 0 {
return true;
}
if let Ok(content) = std::fs::read_to_string(dir.join(APP_METADATA_CONFIG)) {
for line in content.lines() {
if line.starts_with(META_LINE_PREFIX_TIMESTAMP) {
if let Ok(stored_ts) = line.replace(META_LINE_PREFIX_TIMESTAMP, "").parse::<u64>() {
return *ts == stored_ts;
}
}
}
}
false
}
fn write_meta(dir: &Path, ts: u64) {
let meta_file = dir.join(APP_METADATA_CONFIG);
if ts != 0 {
let content = format!("{}{}", META_LINE_PREFIX_TIMESTAMP, ts);
// Ignore is ok here
let _ = std::fs::write(meta_file, content);
}
}
fn setup(
reader: BinaryReader,
dir: Option<PathBuf>,
clear: bool,
_args: &Vec<String>,
_ui: &mut bool,
) -> Option<PathBuf> {
let dir = if let Some(dir) = dir {
dir
} else {
// home dir
if let Some(dir) = dirs::data_local_dir() {
dir.join(APP_PREFIX)
} else {
eprintln!("not found data local dir");
return None;
}
};
let mut ts = 0;
if clear || !is_timestamp_matches(&dir, &mut ts) {
#[cfg(windows)]
if _args.is_empty() {
*_ui = true;
ui::setup();
}
std::fs::remove_dir_all(&dir).ok();
}
for file in reader.files.iter() {
file.write_to_file(&dir);
}
write_meta(&dir, ts);
#[cfg(windows)]
win::copy_runtime_broker(&dir);
#[cfg(linux)]
reader.configure_permission(&dir);
Some(dir.join(&reader.exe))
}
fn use_null_stdio() -> bool {
#[cfg(windows)]
{
// When running in CMD on Windows 7, using Stdio::inherit() with spawn returns an "invalid handle" error.
// Since using Stdio::null() didnt cause any issues, and determining whether the program is launched from CMD or by double-clicking would require calling more APIs during startup, we also use Stdio::null() when launched by double-clicking on Windows 7.
let is_windows_7 = is_windows_7();
println!("is windows7: {}", is_windows_7);
return is_windows_7;
}
#[cfg(not(windows))]
false
}
#[cfg(windows)]
fn is_windows_7() -> bool {
use windows::Wdk::System::SystemServices::RtlGetVersion;
use windows::Win32::System::SystemInformation::OSVERSIONINFOW;
unsafe {
let mut version_info = OSVERSIONINFOW::default();
version_info.dwOSVersionInfoSize = std::mem::size_of::<OSVERSIONINFOW>() as u32;
if RtlGetVersion(&mut version_info).is_ok() {
// Windows 7 is version 6.1
println!(
"Windows version: {}.{}",
version_info.dwMajorVersion, version_info.dwMinorVersion
);
return version_info.dwMajorVersion == 6 && version_info.dwMinorVersion == 1;
}
}
false
}
fn execute(path: PathBuf, args: Vec<String>, _ui: bool) {
println!("executing {}", path.display());
// setup env
let exe = std::env::current_exe().unwrap_or_default();
let exe_name = exe.file_name().unwrap_or_default();
// run executable
let mut cmd = Command::new(path);
cmd.args(args);
#[cfg(windows)]
{
use std::os::windows::process::CommandExt;
cmd.creation_flags(winapi::um::winbase::CREATE_NO_WINDOW);
if _ui {
cmd.env(SET_FOREGROUND_WINDOW_ENV_KEY, "1");
}
}
cmd.env(APPNAME_RUNTIME_ENV_KEY, exe_name);
if use_null_stdio() {
cmd.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null());
} else {
cmd.stdin(Stdio::inherit())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit());
}
let _child = cmd.spawn();
#[cfg(windows)]
if _ui {
match _child {
Ok(child) => unsafe {
winapi::um::winuser::AllowSetForegroundWindow(child.id() as u32);
},
Err(e) => {
eprintln!("{:?}", e);
}
}
}
}
fn main() {
let mut args = Vec::new();
let mut arg_exe = Default::default();
let mut i = 0;
for arg in std::env::args() {
if i == 0 {
arg_exe = arg.clone();
} else {
args.push(arg);
}
i += 1;
}
let click_setup = args.is_empty() && arg_exe.to_lowercase().ends_with("install.exe");
#[cfg(windows)]
let quick_support = args.is_empty() && win::is_quick_support_exe(&arg_exe);
#[cfg(not(windows))]
let quick_support = false;
let mut ui = false;
let reader = BinaryReader::default();
if let Some(exe) = setup(
reader,
None,
click_setup || args.contains(&"--silent-install".to_owned()),
&args,
&mut ui,
) {
if click_setup {
args = vec!["--install".to_owned()];
} else if quick_support {
args = vec!["--quick_support".to_owned()];
}
execute(exe, args, ui);
}
}
#[cfg(windows)]
mod win {
use std::{fs, os::windows::process::CommandExt, path::Path, process::Command};
// Used for privacy mode(magnifier impl).
pub const RUNTIME_BROKER_EXE: &'static str = "C:\\Windows\\System32\\RuntimeBroker.exe";
pub const WIN_TOPMOST_INJECTED_PROCESS_EXE: &'static str = "RuntimeBroker_rustdesk.exe";
pub(super) fn copy_runtime_broker(dir: &Path) {
let src = RUNTIME_BROKER_EXE;
let tgt = WIN_TOPMOST_INJECTED_PROCESS_EXE;
let target_file = dir.join(tgt);
if target_file.exists() {
if let (Ok(src_file), Ok(tgt_file)) = (fs::read(src), fs::read(&target_file)) {
let src_md5 = format!("{:x}", md5::compute(&src_file));
let tgt_md5 = format!("{:x}", md5::compute(&tgt_file));
if src_md5 == tgt_md5 {
return;
}
}
}
let _allow_err = Command::new("taskkill")
.args(&["/F", "/IM", "RuntimeBroker_rustdesk.exe"])
.creation_flags(winapi::um::winbase::CREATE_NO_WINDOW)
.output();
let _allow_err = std::fs::copy(src, &format!("{}\\{}", dir.to_string_lossy(), tgt));
}
/// Check if the executable is a Quick Support version.
/// Note: This function must be kept in sync with `src/core_main.rs`.
#[inline]
pub(super) fn is_quick_support_exe(exe: &str) -> bool {
let exe = exe.to_lowercase();
exe.contains("-qs-") || exe.contains("-qs.exe") || exe.contains("_qs.exe")
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

+232
View File
@@ -0,0 +1,232 @@
use native_windows_gui as nwg;
use nwg::NativeUi;
use std::cell::RefCell;
const GIF_DATA: &[u8] = include_bytes!("./res/spin.gif");
const LABEL_DATA: &[u8] = include_bytes!("./res/label.png");
const GIF_SIZE: i32 = 32;
const BG_COLOR: [u8; 3] = [90, 90, 120];
const BORDER_COLOR: [u8; 3] = [40, 40, 40];
const GIF_DELAY: u64 = 30;
#[derive(Default)]
pub struct BasicApp {
window: nwg::Window,
border_image: nwg::ImageFrame,
bg_image: nwg::ImageFrame,
gif_image: nwg::ImageFrame,
label_image: nwg::ImageFrame,
border_layout: nwg::GridLayout,
bg_layout: nwg::GridLayout,
inner_layout: nwg::GridLayout,
timer: nwg::AnimationTimer,
decoder: nwg::ImageDecoder,
gif_index: RefCell<usize>,
gif_images: RefCell<Vec<nwg::Bitmap>>,
}
impl BasicApp {
fn exit(&self) {
self.timer.stop();
nwg::stop_thread_dispatch();
}
fn load_gif(&self) -> Result<(), nwg::NwgError> {
let image_source = self.decoder.from_stream(GIF_DATA)?;
for frame_index in 0..image_source.frame_count() {
let image_data = image_source.frame(frame_index)?;
let image_data = self
.decoder
.resize_image(&image_data, [GIF_SIZE as u32, GIF_SIZE as u32])?;
let bmp = image_data.as_bitmap()?;
self.gif_images.borrow_mut().push(bmp);
}
Ok(())
}
fn update_gif(&self) -> Result<(), nwg::NwgError> {
let images = self.gif_images.borrow();
if images.len() == 0 {
return Err(nwg::NwgError::ImageDecoderError(
-1,
"no gif images".to_string(),
));
}
let image_index = *self.gif_index.borrow() % images.len();
self.gif_image.set_bitmap(Some(&images[image_index]));
*self.gif_index.borrow_mut() = (image_index + 1) % images.len();
Ok(())
}
fn start_timer(&self) {
self.timer.start();
}
}
mod basic_app_ui {
use super::*;
use native_windows_gui::{self as nwg, Bitmap};
use nwg::{Event, GridLayoutItem};
use std::cell::RefCell;
use std::ops::Deref;
use std::rc::Rc;
pub struct BasicAppUi {
inner: Rc<BasicApp>,
default_handler: RefCell<Vec<nwg::EventHandler>>,
}
impl nwg::NativeUi<BasicAppUi> for BasicApp {
fn build_ui(mut data: BasicApp) -> Result<BasicAppUi, nwg::NwgError> {
data.decoder = nwg::ImageDecoder::new()?;
let col_cnt: i32 = 7;
let row_cnt: i32 = 3;
let border_width: i32 = 1;
let window_size = (
GIF_SIZE * col_cnt + 2 * border_width,
GIF_SIZE * row_cnt + 2 * border_width,
);
// Controls
nwg::Window::builder()
.flags(nwg::WindowFlags::POPUP | nwg::WindowFlags::VISIBLE)
.size(window_size)
.center(true)
.build(&mut data.window)?;
nwg::ImageFrame::builder()
.parent(&data.window)
.size(window_size)
.background_color(Some(BORDER_COLOR))
.build(&mut data.border_image)?;
nwg::ImageFrame::builder()
.parent(&data.border_image)
.size((row_cnt * GIF_SIZE, col_cnt * GIF_SIZE))
.background_color(Some(BG_COLOR))
.build(&mut data.bg_image)?;
nwg::ImageFrame::builder()
.parent(&data.bg_image)
.size((GIF_SIZE, GIF_SIZE))
.background_color(Some(BG_COLOR))
.build(&mut data.gif_image)?;
nwg::ImageFrame::builder()
.parent(&data.bg_image)
.background_color(Some(BG_COLOR))
.bitmap(Some(&Bitmap::from_bin(LABEL_DATA)?))
.build(&mut data.label_image)?;
nwg::AnimationTimer::builder()
.parent(&data.window)
.interval(std::time::Duration::from_millis(GIF_DELAY))
.build(&mut data.timer)?;
// Wrap-up
let ui = BasicAppUi {
inner: Rc::new(data),
default_handler: Default::default(),
};
// Layouts
nwg::GridLayout::builder()
.parent(&ui.window)
.spacing(0)
.margin([0, 0, 0, 0])
.max_column(Some(1))
.max_row(Some(1))
.child_item(GridLayoutItem::new(&ui.border_image, 0, 0, 1, 1))
.build(&ui.border_layout)?;
nwg::GridLayout::builder()
.parent(&ui.border_image)
.spacing(0)
.margin([
border_width as _,
border_width as _,
border_width as _,
border_width as _,
])
.max_column(Some(1))
.max_row(Some(1))
.child_item(GridLayoutItem::new(&ui.bg_image, 0, 0, 1, 1))
.build(&ui.bg_layout)?;
nwg::GridLayout::builder()
.parent(&ui.bg_image)
.spacing(0)
.margin([0, 0, 0, 0])
.max_column(Some(col_cnt as _))
.max_row(Some(row_cnt as _))
.child_item(GridLayoutItem::new(&ui.gif_image, 2, 1, 1, 1))
.child_item(GridLayoutItem::new(&ui.label_image, 3, 1, 3, 1))
.build(&ui.inner_layout)?;
// Events
let evt_ui = Rc::downgrade(&ui.inner);
let handle_events = move |evt, _evt_data, _handle| {
if let Some(evt_ui) = evt_ui.upgrade().as_mut() {
match evt {
Event::OnWindowClose => {
evt_ui.exit();
}
Event::OnTimerTick => {
if let Err(e) = evt_ui.update_gif() {
eprintln!("{:?}", e);
}
}
_ => {}
}
}
};
ui.default_handler
.borrow_mut()
.push(nwg::full_bind_event_handler(
&ui.window.handle,
handle_events,
));
return Ok(ui);
}
}
impl Drop for BasicAppUi {
/// To make sure that everything is freed without issues, the default handler must be unbound.
fn drop(&mut self) {
let mut handlers = self.default_handler.borrow_mut();
for handler in handlers.drain(0..) {
nwg::unbind_event_handler(&handler);
}
}
}
impl Deref for BasicAppUi {
type Target = BasicApp;
fn deref(&self) -> &BasicApp {
&self.inner
}
}
}
fn ui() -> Result<(), nwg::NwgError> {
nwg::init()?;
let app = BasicApp::build_ui(Default::default())?;
app.load_gif()?;
app.start_timer();
nwg::dispatch_thread_events();
Ok(())
}
pub fn setup() {
std::thread::spawn(move || {
if let Err(e) = ui() {
eprintln!("{:?}", e);
}
});
}
+11
View File
@@ -0,0 +1,11 @@
[package]
name = "remote_printer"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[target.'cfg(target_os = "windows")'.dependencies]
hbb_common = { version = "0.1.0", path = "../hbb_common" }
winapi = { version = "0.3" }
windows-strings = "0.3.1"
+34
View File
@@ -0,0 +1,34 @@
#[cfg(target_os = "windows")]
mod setup;
#[cfg(target_os = "windows")]
pub use setup::{
is_rd_printer_installed,
setup::{install_update_printer, uninstall_printer},
};
#[cfg(target_os = "windows")]
const RD_DRIVER_INF_PATH: &str = "drivers/RustDeskPrinterDriver/RustDeskPrinterDriver.inf";
#[cfg(target_os = "windows")]
fn get_printer_name(app_name: &str) -> Vec<u16> {
format!("{} Printer", app_name)
.encode_utf16()
.chain(Some(0))
.collect()
}
#[cfg(target_os = "windows")]
fn get_driver_name() -> Vec<u16> {
"RustDesk v4 Printer Driver"
.encode_utf16()
.chain(Some(0))
.collect()
}
#[cfg(target_os = "windows")]
fn get_port_name(app_name: &str) -> Vec<u16> {
format!("{} Printer", app_name)
.encode_utf16()
.chain(Some(0))
.collect()
}
+202
View File
@@ -0,0 +1,202 @@
use super::{common_enum, get_wstr_bytes, is_name_equal};
use hbb_common::{bail, log, ResultType};
use std::{io, ptr::null_mut, time::Duration};
use winapi::{
shared::{
minwindef::{BOOL, DWORD, FALSE, LPBYTE, LPDWORD, MAX_PATH},
ntdef::{DWORDLONG, LPCWSTR},
winerror::{ERROR_UNKNOWN_PRINTER_DRIVER, S_OK},
},
um::{
winspool::{
DeletePrinterDriverExW, DeletePrinterDriverPackageW, EnumPrinterDriversW,
InstallPrinterDriverFromPackageW, UploadPrinterDriverPackageW, DPD_DELETE_ALL_FILES,
DRIVER_INFO_6W, DRIVER_INFO_8W, IPDFP_COPY_ALL_FILES, UPDP_SILENT_UPLOAD,
UPDP_UPLOAD_ALWAYS,
},
winuser::GetForegroundWindow,
},
};
use windows_strings::PCWSTR;
const HRESULT_ERR_ELEMENT_NOT_FOUND: u32 = 0x80070490;
fn enum_printer_driver(
level: DWORD,
p_driver_info: LPBYTE,
cb_buf: DWORD,
pcb_needed: LPDWORD,
pc_returned: LPDWORD,
) -> BOOL {
unsafe {
// https://learn.microsoft.com/en-us/windows/win32/printdocs/enumprinterdrivers
// This is a blocking or synchronous function and might not return immediately.
// How quickly this function returns depends on run-time factors
// such as network status, print server configuration, and printer driver implementation factors that are difficult to predict when writing an application.
// Calling this function from a thread that manages interaction with the user interface could make the application appear to be unresponsive.
EnumPrinterDriversW(
null_mut(),
null_mut(),
level,
p_driver_info,
cb_buf,
pcb_needed,
pc_returned,
)
}
}
pub fn get_installed_driver_version(name: &PCWSTR) -> ResultType<Option<DWORDLONG>> {
common_enum(
"EnumPrinterDriversW",
enum_printer_driver,
6,
|info: &DRIVER_INFO_6W| {
if is_name_equal(name, info.pName) {
Some(info.dwlDriverVersion)
} else {
None
}
},
|| None,
)
}
fn find_inf(name: &PCWSTR) -> ResultType<Vec<u16>> {
let r = common_enum(
"EnumPrinterDriversW",
enum_printer_driver,
8,
|info: &DRIVER_INFO_8W| {
if is_name_equal(name, info.pName) {
Some(get_wstr_bytes(info.pszInfPath))
} else {
None
}
},
|| None,
)?;
Ok(r.unwrap_or(vec![]))
}
fn delete_printer_driver(name: &PCWSTR) -> ResultType<()> {
unsafe {
// If the printer is used after the spooler service is started. E.g., printing a document through RustDesk Printer.
// `DeletePrinterDriverExW()` may fail with `ERROR_PRINTER_DRIVER_IN_USE`(3001, 0xBB9).
// We can only ignore this error for now.
// Though restarting the spooler service is a solution, it's not a good idea to restart the service.
//
// Deleting the printer driver after deleting the printer is a common practice.
// No idea why `DeletePrinterDriverExW()` fails with `ERROR_UNKNOWN_PRINTER_DRIVER` after using the printer once.
// https://github.com/ChromiumWebApps/chromium/blob/c7361d39be8abd1574e6ce8957c8dbddd4c6ccf7/cloud_print/virtual_driver/win/install/setup.cc#L422
// AnyDesk printer driver and the simplest printer driver also have the same issue.
if FALSE
== DeletePrinterDriverExW(
null_mut(),
null_mut(),
name.as_ptr() as _,
DPD_DELETE_ALL_FILES,
0,
)
{
let err = io::Error::last_os_error();
if err.raw_os_error() == Some(ERROR_UNKNOWN_PRINTER_DRIVER as _) {
return Ok(());
} else {
bail!("Failed to delete the printer driver, {}", err)
}
}
}
Ok(())
}
// https://github.com/dvalter/chromium-android-ext-dev/blob/dab74f7d5bc5a8adf303090ee25c611b4d54e2db/cloud_print/virtual_driver/win/install/setup.cc#L190
fn delete_printer_driver_package(inf: Vec<u16>) -> ResultType<()> {
if inf.is_empty() {
return Ok(());
}
let slen = if inf[inf.len() - 1] == 0 {
inf.len() - 1
} else {
inf.len()
};
let inf_path = String::from_utf16_lossy(&inf[..slen]);
if !std::path::Path::new(&inf_path).exists() {
return Ok(());
}
let mut retries = 3;
loop {
unsafe {
let res = DeletePrinterDriverPackageW(null_mut(), inf.as_ptr(), null_mut());
if res == S_OK || res == HRESULT_ERR_ELEMENT_NOT_FOUND as i32 {
return Ok(());
}
log::error!("Failed to delete the printer driver, result: {}", res);
}
retries -= 1;
if retries <= 0 {
bail!("Failed to delete the printer driver");
}
std::thread::sleep(Duration::from_secs(2));
}
}
pub fn uninstall_driver(name: &PCWSTR) -> ResultType<()> {
// Note: inf must be found before `delete_printer_driver()`.
let inf = find_inf(name)?;
delete_printer_driver(name)?;
delete_printer_driver_package(inf)
}
pub fn install_driver(name: &PCWSTR, inf: LPCWSTR) -> ResultType<()> {
let mut size = (MAX_PATH * 10) as u32;
let mut package_path = [0u16; MAX_PATH * 10];
unsafe {
let mut res = UploadPrinterDriverPackageW(
null_mut(),
inf,
null_mut(),
UPDP_SILENT_UPLOAD | UPDP_UPLOAD_ALWAYS,
null_mut(),
package_path.as_mut_ptr(),
&mut size as _,
);
if res != S_OK {
log::error!(
"Failed to upload the printer driver package to the driver cache silently, {}. Will try with user UI.",
res
);
res = UploadPrinterDriverPackageW(
null_mut(),
inf,
null_mut(),
UPDP_UPLOAD_ALWAYS,
GetForegroundWindow(),
package_path.as_mut_ptr(),
&mut size as _,
);
if res != S_OK {
bail!(
"Failed to upload the printer driver package to the driver cache with UI, {}",
res
);
}
}
// https://learn.microsoft.com/en-us/windows/win32/printdocs/installprinterdriverfrompackage
res = InstallPrinterDriverFromPackageW(
null_mut(),
package_path.as_ptr(),
name.as_ptr(),
null_mut(),
IPDFP_COPY_ALL_FILES,
);
if res != S_OK {
bail!("Failed to install the printer driver from package, {}", res);
}
}
Ok(())
}
+101
View File
@@ -0,0 +1,101 @@
#![allow(non_snake_case)]
use hbb_common::{bail, ResultType};
use std::{io, ptr::null_mut};
use winapi::{
shared::{
minwindef::{BOOL, DWORD, FALSE, LPBYTE, LPDWORD},
ntdef::{LPCWSTR, LPWSTR},
},
um::winbase::{lstrcmpiW, lstrlenW},
};
use windows_strings::PCWSTR;
mod driver;
mod port;
pub(crate) mod printer;
pub(crate) mod setup;
#[inline]
pub fn is_rd_printer_installed(app_name: &str) -> ResultType<bool> {
let printer_name = crate::get_printer_name(app_name);
let rd_printer_name = PCWSTR::from_raw(printer_name.as_ptr());
printer::is_printer_added(&rd_printer_name)
}
fn get_wstr_bytes(p: LPWSTR) -> Vec<u16> {
let mut vec_bytes = vec![];
unsafe {
let len: isize = lstrlenW(p) as _;
if len > 0 {
for i in 0..len + 1 {
vec_bytes.push(*p.offset(i));
}
}
}
vec_bytes
}
fn is_name_equal(name: &PCWSTR, name_from_api: LPCWSTR) -> bool {
// https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-lstrcmpiw
// For some locales, the lstrcmpi function may be insufficient.
// If this occurs, use `CompareStringEx` to ensure proper comparison.
// For example, in Japan call with the NORM_IGNORECASE, NORM_IGNOREKANATYPE, and NORM_IGNOREWIDTH values to achieve the most appropriate non-exact string comparison.
// Note that specifying these values slows performance, so use them only when necessary.
//
// No need to consider `CompareStringEx` for now.
unsafe { lstrcmpiW(name.as_ptr(), name_from_api) == 0 }
}
fn common_enum<T, R: Sized>(
enum_name: &str,
enum_fn: fn(
Level: DWORD,
pDriverInfo: LPBYTE,
cbBuf: DWORD,
pcbNeeded: LPDWORD,
pcReturned: LPDWORD,
) -> BOOL,
level: DWORD,
on_data: impl Fn(&T) -> Option<R>,
on_no_data: impl Fn() -> Option<R>,
) -> ResultType<Option<R>> {
let mut needed = 0;
let mut returned = 0;
enum_fn(level, null_mut(), 0, &mut needed, &mut returned);
if needed == 0 {
return Ok(on_no_data());
}
let mut buffer = vec![0u8; needed as usize];
if FALSE
== enum_fn(
level,
buffer.as_mut_ptr(),
needed,
&mut needed,
&mut returned,
)
{
bail!(
"Failed to call {}, error: {}",
enum_name,
io::Error::last_os_error()
)
}
// to-do: how to free the buffers in *const T?
let p_enum_info = buffer.as_ptr() as *const T;
unsafe {
for i in 0..returned {
let enum_info = p_enum_info.offset(i as isize);
let r = on_data(&*enum_info);
if r.is_some() {
return Ok(r);
}
}
}
Ok(on_no_data())
}
+128
View File
@@ -0,0 +1,128 @@
use super::{common_enum, is_name_equal, printer::get_printer_installed_on_port};
use hbb_common::{bail, ResultType};
use std::{io, ptr::null_mut};
use winapi::{
shared::minwindef::{BOOL, DWORD, FALSE, LPBYTE, LPDWORD},
um::{
winnt::HANDLE,
winspool::{
ClosePrinter, EnumPortsW, OpenPrinterW, XcvDataW, PORT_INFO_2W, PRINTER_DEFAULTSW,
SERVER_WRITE,
},
},
};
use windows_strings::{w, PCWSTR};
const XCV_MONITOR_LOCAL_PORT: PCWSTR = w!(",XcvMonitor Local Port");
fn enum_printer_port(
level: DWORD,
p_port_info: LPBYTE,
cb_buf: DWORD,
pcb_needed: LPDWORD,
pc_returned: LPDWORD,
) -> BOOL {
unsafe {
// https://learn.microsoft.com/en-us/windows/win32/printdocs/enumports
// This is a blocking or synchronous function and might not return immediately.
// How quickly this function returns depends on run-time factors
// such as network status, print server configuration, and printer driver implementation factors that are difficult to predict when writing an application.
// Calling this function from a thread that manages interaction with the user interface could make the application appear to be unresponsive.
EnumPortsW(
null_mut(),
level,
p_port_info,
cb_buf,
pcb_needed,
pc_returned,
)
}
}
fn is_port_exists(name: &PCWSTR) -> ResultType<bool> {
let r = common_enum(
"EnumPortsW",
enum_printer_port,
2,
|info: &PORT_INFO_2W| {
if is_name_equal(name, info.pPortName) {
Some(true)
} else {
None
}
},
|| None,
)?;
Ok(r.unwrap_or(false))
}
unsafe fn execute_on_local_port(port: &PCWSTR, command: &PCWSTR) -> ResultType<()> {
let mut dft = PRINTER_DEFAULTSW {
pDataType: null_mut(),
pDevMode: null_mut(),
DesiredAccess: SERVER_WRITE,
};
let mut h_monitor: HANDLE = null_mut();
if FALSE
== OpenPrinterW(
XCV_MONITOR_LOCAL_PORT.as_ptr() as _,
&mut h_monitor,
&mut dft as *mut PRINTER_DEFAULTSW as _,
)
{
bail!(format!(
"Failed to open Local Port monitor. Error: {}",
io::Error::last_os_error()
))
}
let mut output_needed: u32 = 0;
let mut status: u32 = 0;
if FALSE
== XcvDataW(
h_monitor,
command.as_ptr(),
port.as_ptr() as *mut u8,
(port.len() + 1) as u32 * 2,
null_mut(),
0,
&mut output_needed,
&mut status,
)
{
ClosePrinter(h_monitor);
bail!(format!(
"Failed to execute the command on the printer port, Error: {}",
io::Error::last_os_error()
))
}
ClosePrinter(h_monitor);
Ok(())
}
fn add_local_port(port: &PCWSTR) -> ResultType<()> {
unsafe { execute_on_local_port(port, &w!("AddPort")) }
}
fn delete_local_port(port: &PCWSTR) -> ResultType<()> {
unsafe { execute_on_local_port(port, &w!("DeletePort")) }
}
pub fn check_add_local_port(port: &PCWSTR) -> ResultType<()> {
if !is_port_exists(port)? {
return add_local_port(port);
}
Ok(())
}
pub fn check_delete_local_port(port: &PCWSTR) -> ResultType<()> {
if is_port_exists(port)? {
if get_printer_installed_on_port(port)?.is_some() {
bail!("The printer is installed on the port. Please remove the printer first.");
}
return delete_local_port(port);
}
Ok(())
}
+161
View File
@@ -0,0 +1,161 @@
use super::{common_enum, get_wstr_bytes, is_name_equal};
use hbb_common::{bail, ResultType};
use std::{io, ptr::null_mut};
use winapi::{
shared::{
minwindef::{BOOL, DWORD, FALSE, LPBYTE, LPDWORD},
ntdef::HANDLE,
winerror::ERROR_INVALID_PRINTER_NAME,
},
um::winspool::{
AddPrinterW, ClosePrinter, DeletePrinter, EnumPrintersW, OpenPrinterW, SetPrinterW,
PRINTER_ALL_ACCESS, PRINTER_ATTRIBUTE_LOCAL, PRINTER_CONTROL_PURGE, PRINTER_DEFAULTSW,
PRINTER_ENUM_LOCAL, PRINTER_INFO_1W, PRINTER_INFO_2W,
},
};
use windows_strings::{w, PCWSTR};
fn enum_local_printer(
level: DWORD,
p_printer_info: LPBYTE,
cb_buf: DWORD,
pcb_needed: LPDWORD,
pc_returned: LPDWORD,
) -> BOOL {
unsafe {
// https://learn.microsoft.com/en-us/windows/win32/printdocs/enumprinters
// This is a blocking or synchronous function and might not return immediately.
// How quickly this function returns depends on run-time factors
// such as network status, print server configuration, and printer driver implementation factors that are difficult to predict when writing an application.
// Calling this function from a thread that manages interaction with the user interface could make the application appear to be unresponsive.
EnumPrintersW(
PRINTER_ENUM_LOCAL,
null_mut(),
level,
p_printer_info,
cb_buf,
pcb_needed,
pc_returned,
)
}
}
#[inline]
pub fn is_printer_added(name: &PCWSTR) -> ResultType<bool> {
let r = common_enum(
"EnumPrintersW",
enum_local_printer,
1,
|info: &PRINTER_INFO_1W| {
if is_name_equal(name, info.pName) {
Some(true)
} else {
None
}
},
|| None,
)?;
Ok(r.unwrap_or(false))
}
// Only return the first matched printer
pub fn get_printer_installed_on_port(port: &PCWSTR) -> ResultType<Option<Vec<u16>>> {
common_enum(
"EnumPrintersW",
enum_local_printer,
2,
|info: &PRINTER_INFO_2W| {
if is_name_equal(port, info.pPortName) {
Some(get_wstr_bytes(info.pPrinterName))
} else {
None
}
},
|| None,
)
}
pub fn add_printer(name: &PCWSTR, driver: &PCWSTR, port: &PCWSTR) -> ResultType<()> {
let mut printer_info = PRINTER_INFO_2W {
pServerName: null_mut(),
pPrinterName: name.as_ptr() as _,
pShareName: null_mut(),
pPortName: port.as_ptr() as _,
pDriverName: driver.as_ptr() as _,
pComment: null_mut(),
pLocation: null_mut(),
pDevMode: null_mut(),
pSepFile: null_mut(),
pPrintProcessor: w!("WinPrint").as_ptr() as _,
pDatatype: w!("RAW").as_ptr() as _,
pParameters: null_mut(),
pSecurityDescriptor: null_mut(),
Attributes: PRINTER_ATTRIBUTE_LOCAL,
Priority: 0,
DefaultPriority: 0,
StartTime: 0,
UntilTime: 0,
Status: 0,
cJobs: 0,
AveragePPM: 0,
};
unsafe {
let h_printer = AddPrinterW(
null_mut(),
2,
&mut printer_info as *mut PRINTER_INFO_2W as _,
);
if h_printer.is_null() {
bail!(format!(
"Failed to add printer. Error: {}",
io::Error::last_os_error()
))
}
}
Ok(())
}
pub fn delete_printer(name: &PCWSTR) -> ResultType<()> {
let mut dft = PRINTER_DEFAULTSW {
pDataType: null_mut(),
pDevMode: null_mut(),
DesiredAccess: PRINTER_ALL_ACCESS,
};
let mut h_printer: HANDLE = null_mut();
unsafe {
if FALSE
== OpenPrinterW(
name.as_ptr() as _,
&mut h_printer,
&mut dft as *mut PRINTER_DEFAULTSW as _,
)
{
let err = io::Error::last_os_error();
if err.raw_os_error() == Some(ERROR_INVALID_PRINTER_NAME as _) {
return Ok(());
} else {
bail!(format!("Failed to open printer. Error: {}", err))
}
}
if FALSE == SetPrinterW(h_printer, 0, null_mut(), PRINTER_CONTROL_PURGE) {
ClosePrinter(h_printer);
bail!(format!(
"Failed to purge printer queue. Error: {}",
io::Error::last_os_error()
))
}
if FALSE == DeletePrinter(h_printer) {
ClosePrinter(h_printer);
bail!(format!(
"Failed to delete printer. Error: {}",
io::Error::last_os_error()
))
}
ClosePrinter(h_printer);
}
Ok(())
}
+94
View File
@@ -0,0 +1,94 @@
use super::{
driver::{get_installed_driver_version, install_driver, uninstall_driver},
port::{check_add_local_port, check_delete_local_port},
printer::{add_printer, delete_printer},
};
use hbb_common::{allow_err, bail, lazy_static, log, ResultType};
use std::{path::PathBuf, sync::Mutex};
use windows_strings::PCWSTR;
lazy_static::lazy_static!(
static ref SETUP_MTX: Mutex<()> = Mutex::new(());
);
fn get_driver_inf_abs_path() -> ResultType<PathBuf> {
use crate::RD_DRIVER_INF_PATH;
let exe_file = std::env::current_exe()?;
let abs_path = match exe_file.parent() {
Some(parent) => parent.join(RD_DRIVER_INF_PATH),
None => bail!(
"Invalid exe parent for {}",
exe_file.to_string_lossy().as_ref()
),
};
if !abs_path.exists() {
bail!(
"The driver inf file \"{}\" does not exists",
RD_DRIVER_INF_PATH
)
}
Ok(abs_path)
}
// Note: This function must be called in a separate thread.
// Because many functions in this module are blocking or synchronous.
// Calling this function from a thread that manages interaction with the user interface could make the application appear to be unresponsive.
// Steps:
// 1. Add the local port.
// 2. Check if the driver is installed.
// Uninstall the existing driver if it is installed.
// We should not check the driver version because the driver is deployed with the application.
// It's better to uninstall the existing driver and install the driver from the application.
// 3. Add the printer.
pub fn install_update_printer(app_name: &str) -> ResultType<()> {
let printer_name = crate::get_printer_name(app_name);
let driver_name = crate::get_driver_name();
let port = crate::get_port_name(app_name);
let rd_printer_name = PCWSTR::from_raw(printer_name.as_ptr());
let rd_printer_driver_name = PCWSTR::from_raw(driver_name.as_ptr());
let rd_printer_port = PCWSTR::from_raw(port.as_ptr());
let inf_file = get_driver_inf_abs_path()?;
let inf_file: Vec<u16> = inf_file
.to_string_lossy()
.as_ref()
.encode_utf16()
.chain(Some(0).into_iter())
.collect();
let _lock = SETUP_MTX.lock().unwrap();
check_add_local_port(&rd_printer_port)?;
let should_install_driver = match get_installed_driver_version(&rd_printer_driver_name)? {
Some(_version) => {
delete_printer(&rd_printer_name)?;
allow_err!(uninstall_driver(&rd_printer_driver_name));
true
}
None => true,
};
if should_install_driver {
allow_err!(install_driver(&rd_printer_driver_name, inf_file.as_ptr()));
}
add_printer(&rd_printer_name, &rd_printer_driver_name, &rd_printer_port)?;
Ok(())
}
pub fn uninstall_printer(app_name: &str) {
let printer_name = crate::get_printer_name(app_name);
let driver_name = crate::get_driver_name();
let port = crate::get_port_name(app_name);
let rd_printer_name = PCWSTR::from_raw(printer_name.as_ptr());
let rd_printer_driver_name = PCWSTR::from_raw(driver_name.as_ptr());
let rd_printer_port = PCWSTR::from_raw(port.as_ptr());
let _lock = SETUP_MTX.lock().unwrap();
allow_err!(delete_printer(&rd_printer_name));
allow_err!(uninstall_driver(&rd_printer_driver_name));
allow_err!(check_delete_local_port(&rd_printer_port));
}
+4
View File
@@ -0,0 +1,4 @@
/target/
**/*.rs.bk
Cargo.lock
generated/
+68
View File
@@ -0,0 +1,68 @@
[package]
name = "scrap"
description = "Screen capture made easy."
version = "0.5.0"
repository = "https://github.com/quadrupleslap/scrap"
documentation = "https://docs.rs/scrap"
keywords = ["screen", "capture", "record"]
license = "MIT"
authors = ["Ram <quadrupleslap@gmail.com>"]
edition = "2018"
[features]
wayland = ["gstreamer", "gstreamer-app", "gstreamer-video", "dbus", "tracing", "zbus"]
mediacodec = ["ndk"]
linux-pkg-config = ["dep:pkg-config"]
hwcodec = ["dep:hwcodec"]
vram = ["hwcodec/vram"]
[dependencies]
cfg-if = "1.0"
num_cpus = "1.15"
lazy_static = "1.4"
hbb_common = { path = "../hbb_common" }
webm = { git = "https://github.com/rustdesk-org/rust-webm" }
serde = {version="1.0", features=["derive"]}
[dependencies.winapi]
version = "0.3"
default-features = true
features = ["dxgi", "dxgi1_2", "dxgi1_5", "d3d11", "winuser", "winerror", "errhandlingapi", "libloaderapi"]
[target.'cfg(target_os = "macos")'.dependencies]
block = "0.1"
[target.'cfg(target_os = "android")'.dependencies]
android_logger = "0.13"
jni = "0.21"
lazy_static = "1.4"
log = "0.4"
serde_json = "1.0"
ndk = { version = "0.7", features = ["media"], optional = true}
ndk-context = "0.1"
[target.'cfg(not(target_os = "android"))'.dev-dependencies]
repng = "0.2"
docopt = "1.1"
quest = "0.3"
[build-dependencies]
target_build_utils = "0.3"
bindgen = "0.65"
pkg-config = { version = "0.3.27", optional = true }
[target.'cfg(target_os = "linux")'.dependencies]
dbus = { version = "0.9", optional = true }
tracing = { version = "0.1", optional = true }
gstreamer = { version = "0.16", optional = true }
gstreamer-app = { version = "0.16", features = ["v1_10"], optional = true }
gstreamer-video = { version = "0.16", optional = true }
zbus = { version = "3.15", optional = true }
[dependencies.hwcodec]
git = "https://github.com/rustdesk-org/hwcodec"
optional = true
[target.'cfg(any(target_os = "windows", target_os = "linux"))'.dependencies]
nokhwa = { git = "https://github.com/rustdesk-org/nokhwa.git", branch = "fix_from_raw_parts", features = ["input-native"] }
+267
View File
@@ -0,0 +1,267 @@
use std::{
env, fs,
path::{Path, PathBuf},
println,
};
#[cfg(all(target_os = "linux", feature = "linux-pkg-config"))]
fn link_pkg_config(name: &str) -> Vec<PathBuf> {
// sometimes an override is needed
let pc_name = match name {
"libvpx" => "vpx",
_ => name,
};
let lib = pkg_config::probe_library(pc_name)
.expect(format!(
"unable to find '{pc_name}' development headers with pkg-config (feature linux-pkg-config is enabled).
try installing '{pc_name}-dev' from your system package manager.").as_str());
lib.include_paths
}
#[cfg(not(all(target_os = "linux", feature = "linux-pkg-config")))]
fn link_pkg_config(_name: &str) -> Vec<PathBuf> {
unimplemented!()
}
/// Link vcpkg package.
fn link_vcpkg(mut path: PathBuf, name: &str) -> PathBuf {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
let mut target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
if target_arch == "x86_64" {
target_arch = "x64".to_owned();
} else if target_arch == "x86" {
target_arch = "x86".to_owned();
} else if target_arch == "loongarch64" {
target_arch = "loongarch64".to_owned();
} else if target_arch == "aarch64" {
target_arch = "arm64".to_owned();
} else {
target_arch = "arm".to_owned();
}
let mut target = if target_os == "macos" {
if target_arch == "x64" {
"x64-osx".to_owned()
} else if target_arch == "arm64" {
"arm64-osx".to_owned()
} else {
format!("{}-{}", target_arch, target_os)
}
} else if target_os == "windows" {
"x64-windows-static".to_owned()
} else {
format!("{}-{}", target_arch, target_os)
};
if target_arch == "x86" {
target = target.replace("x64", "x86");
}
println!("cargo:info={}", target);
if let Ok(vcpkg_root) = std::env::var("VCPKG_INSTALLED_ROOT") {
path = vcpkg_root.into();
} else {
path.push("installed");
}
path.push(target);
println!(
"cargo:rustc-link-lib=static={}",
name.trim_start_matches("lib")
);
println!(
"cargo:rustc-link-search={}",
path.join("lib").to_str().unwrap()
);
let include = path.join("include");
println!("cargo:include={}", include.to_str().unwrap());
include
}
/// Link homebrew package(for Mac M1).
fn link_homebrew_m1(name: &str) -> PathBuf {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
if target_os != "macos" || target_arch != "aarch64" {
panic!("Couldn't find VCPKG_ROOT, also can't fallback to homebrew because it's only for macos aarch64.");
}
let mut path = PathBuf::from("/opt/homebrew/Cellar");
path.push(name);
let entries = if let Ok(dir) = std::fs::read_dir(&path) {
dir
} else {
panic!("Could not find package in {}. Make sure your homebrew and package {} are all installed.", path.to_str().unwrap(),&name);
};
let mut directories = entries
.into_iter()
.filter(|x| x.is_ok())
.map(|x| x.unwrap().path())
.filter(|x| x.is_dir())
.collect::<Vec<_>>();
// Find the newest version.
directories.sort_unstable();
if directories.is_empty() {
panic!(
"There's no installed version of {} in /opt/homebrew/Cellar",
name
);
}
path.push(directories.pop().unwrap());
// Link the library.
println!(
"cargo:rustc-link-lib=static={}",
name.trim_start_matches("lib")
);
// Add the library path.
println!(
"cargo:rustc-link-search={}",
path.join("lib").to_str().unwrap()
);
// Add the include path.
let include = path.join("include");
println!("cargo:include={}", include.to_str().unwrap());
include
}
/// Find package. By default, it will try to find vcpkg first, then homebrew(currently only for Mac M1).
/// If building for linux and feature "linux-pkg-config" is enabled, will try to use pkg-config
/// unless check fails (e.g. NO_PKG_CONFIG_libyuv=1)
fn find_package(name: &str) -> Vec<PathBuf> {
let no_pkg_config_var_name = format!("NO_PKG_CONFIG_{name}");
println!("cargo:rerun-if-env-changed={no_pkg_config_var_name}");
if cfg!(all(target_os = "linux", feature = "linux-pkg-config"))
&& std::env::var(no_pkg_config_var_name).as_deref() != Ok("1")
{
link_pkg_config(name)
} else if let Ok(vcpkg_root) = std::env::var("VCPKG_ROOT") {
vec![link_vcpkg(vcpkg_root.into(), name)]
} else {
// Try using homebrew
vec![link_homebrew_m1(name)]
}
}
fn generate_bindings(
ffi_header: &Path,
include_paths: &[PathBuf],
ffi_rs: &Path,
exact_file: &Path,
regex: &str,
) {
let mut b = bindgen::builder()
.header(ffi_header.to_str().unwrap())
.allowlist_type(regex)
.allowlist_var(regex)
.allowlist_function(regex)
.rustified_enum(regex)
.trust_clang_mangling(false)
.layout_tests(false) // breaks 32/64-bit compat
.generate_comments(false); // comments have prefix /*!\
for dir in include_paths {
b = b.clang_arg(format!("-I{}", dir.display()));
}
b.generate().unwrap().write_to_file(ffi_rs).unwrap();
fs::copy(ffi_rs, exact_file).ok(); // ignore failure
}
fn gen_vcpkg_package(package: &str, ffi_header: &str, generated: &str, regex: &str) {
let includes = find_package(package);
let src_dir = env::var_os("CARGO_MANIFEST_DIR").unwrap();
let src_dir = Path::new(&src_dir);
let out_dir = env::var_os("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let ffi_header = src_dir.join("src").join("bindings").join(ffi_header);
println!("rerun-if-changed={}", ffi_header.display());
for dir in &includes {
println!("rerun-if-changed={}", dir.display());
}
let ffi_rs = out_dir.join(generated);
let exact_file = src_dir.join("generated").join(generated);
generate_bindings(&ffi_header, &includes, &ffi_rs, &exact_file, regex);
}
// If you have problems installing ffmpeg, you can download $VCPKG_ROOT/installed from ci
// Linux require link in hwcodec
/*
fn ffmpeg() {
// ffmpeg
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let static_libs = vec!["avcodec", "avutil", "avformat"];
static_libs.iter().for_each(|lib| {
find_package(lib);
});
if target_os == "windows" {
println!("cargo:rustc-link-lib=static=libmfx");
}
// os
let dyn_libs: Vec<&str> = if target_os == "windows" {
["User32", "bcrypt", "ole32", "advapi32"].to_vec()
} else if target_os == "linux" {
let mut v = ["va", "va-drm", "va-x11", "vdpau", "X11", "stdc++"].to_vec();
if target_arch == "x86_64" {
v.push("z");
}
v
} else if target_os == "macos" || target_os == "ios" {
["c++", "m"].to_vec()
} else if target_os == "android" {
["z", "m", "android", "atomic"].to_vec()
} else {
panic!("unsupported os");
};
dyn_libs
.iter()
.map(|lib| println!("cargo:rustc-link-lib={}", lib))
.count();
if target_os == "macos" || target_os == "ios" {
println!("cargo:rustc-link-lib=framework=CoreFoundation");
println!("cargo:rustc-link-lib=framework=CoreVideo");
println!("cargo:rustc-link-lib=framework=CoreMedia");
println!("cargo:rustc-link-lib=framework=VideoToolbox");
println!("cargo:rustc-link-lib=framework=AVFoundation");
}
}
*/
fn main() {
// in this crate, these are also valid configurations
println!("cargo:rustc-check-cfg=cfg(dxgi,quartz,x11)");
// there is problem with cfg(target_os) in build.rs, so use our workaround
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
// note: all link symbol names in x86 (32-bit) are prefixed wth "_".
// run "rustup show" to show current default toolchain, if it is stable-x86-pc-windows-msvc,
// please install x64 toolchain by "rustup toolchain install stable-x86_64-pc-windows-msvc",
// then set x64 to default by "rustup default stable-x86_64-pc-windows-msvc"
let target = target_build_utils::TargetInfo::new();
if target.unwrap().target_pointer_width() != "64" {
// panic!("Only support 64bit system");
}
env::remove_var("CARGO_CFG_TARGET_FEATURE");
env::set_var("CARGO_CFG_TARGET_FEATURE", "crt-static");
find_package("libyuv");
gen_vcpkg_package("libvpx", "vpx_ffi.h", "vpx_ffi.rs", "^[vV].*");
gen_vcpkg_package("aom", "aom_ffi.h", "aom_ffi.rs", "^(aom|AOM|OBU|AV1).*");
gen_vcpkg_package("libyuv", "yuv_ffi.h", "yuv_ffi.rs", ".*");
// ffmpeg();
if target_os == "ios" {
// nothing
} else if target_os == "android" {
println!("cargo:rustc-cfg=android");
} else if cfg!(windows) {
// The first choice is Windows because DXGI is amazing.
println!("cargo:rustc-cfg=dxgi");
} else if cfg!(target_os = "macos") {
// Quartz is second because macOS is the (annoying) exception.
println!("cargo:rustc-cfg=quartz");
} else if cfg!(unix) {
// On UNIX we pray that X11 (with XCB) is available.
println!("cargo:rustc-cfg=x11");
}
}
+511
View File
@@ -0,0 +1,511 @@
use jni::objects::JByteBuffer;
use jni::objects::JString;
use jni::objects::JValue;
use jni::sys::jboolean;
use jni::JNIEnv;
use jni::{
objects::{GlobalRef, JClass, JObject},
strings::JNIString,
JavaVM,
};
use hbb_common::{message_proto::MultiClipboards, protobuf::Message};
use jni::errors::{Error as JniError, Result as JniResult};
use lazy_static::lazy_static;
use serde::Deserialize;
use std::ops::Not;
use std::os::raw::c_void;
use std::sync::atomic::{AtomicPtr, Ordering::SeqCst};
use std::sync::{Mutex, RwLock};
use std::time::{Duration, Instant};
lazy_static! {
static ref JVM: RwLock<Option<JavaVM>> = RwLock::new(None);
static ref MAIN_SERVICE_CTX: RwLock<Option<GlobalRef>> = RwLock::new(None); // MainService -> video service / audio service / info
static ref APPLICATION_CONTEXT: RwLock<Option<GlobalRef>> = RwLock::new(None);
static ref VIDEO_RAW: Mutex<FrameRaw> = Mutex::new(FrameRaw::new("video", MAX_VIDEO_FRAME_TIMEOUT));
static ref AUDIO_RAW: Mutex<FrameRaw> = Mutex::new(FrameRaw::new("audio", MAX_AUDIO_FRAME_TIMEOUT));
static ref NDK_CONTEXT_INITED: Mutex<bool> = Default::default();
static ref MEDIA_CODEC_INFOS: RwLock<Option<MediaCodecInfos>> = RwLock::new(None);
static ref CLIPBOARD_MANAGER: RwLock<Option<GlobalRef>> = RwLock::new(None);
static ref CLIPBOARDS_HOST: Mutex<Option<MultiClipboards>> = Mutex::new(None);
static ref CLIPBOARDS_CLIENT: Mutex<Option<MultiClipboards>> = Mutex::new(None);
}
const MAX_VIDEO_FRAME_TIMEOUT: Duration = Duration::from_millis(100);
const MAX_AUDIO_FRAME_TIMEOUT: Duration = Duration::from_millis(1000);
struct FrameRaw {
name: &'static str,
ptr: AtomicPtr<u8>,
len: usize,
last_update: Instant,
timeout: Duration,
enable: bool,
}
impl FrameRaw {
fn new(name: &'static str, timeout: Duration) -> Self {
FrameRaw {
name,
ptr: AtomicPtr::default(),
len: 0,
last_update: Instant::now(),
timeout,
enable: false,
}
}
fn set_enable(&mut self, value: bool) {
self.enable = value;
self.ptr.store(std::ptr::null_mut(), SeqCst);
self.len = 0;
}
fn update(&mut self, data: *mut u8, len: usize) {
if self.enable.not() {
return;
}
self.len = len;
self.ptr.store(data, SeqCst);
self.last_update = Instant::now();
}
// take inner data as slice
// release when success
fn take<'a>(&mut self, dst: &mut Vec<u8>, last: &mut Vec<u8>) -> Option<()> {
if self.enable.not() {
return None;
}
let ptr = self.ptr.load(SeqCst);
if ptr.is_null() || self.len == 0 {
None
} else {
if self.last_update.elapsed() > self.timeout {
log::trace!("Failed to take {} raw,timeout!", self.name);
return None;
}
let slice = unsafe { std::slice::from_raw_parts(ptr, self.len) };
self.release();
if last.len() == slice.len() && crate::would_block_if_equal(last, slice).is_err() {
return None;
}
dst.resize(slice.len(), 0);
unsafe {
std::ptr::copy_nonoverlapping(slice.as_ptr(), dst.as_mut_ptr(), slice.len());
}
Some(())
}
}
fn release(&mut self) {
self.len = 0;
self.ptr.store(std::ptr::null_mut(), SeqCst);
}
}
pub fn get_video_raw<'a>(dst: &mut Vec<u8>, last: &mut Vec<u8>) -> Option<()> {
VIDEO_RAW.lock().ok()?.take(dst, last)
}
pub fn get_audio_raw<'a>(dst: &mut Vec<u8>, last: &mut Vec<u8>) -> Option<()> {
AUDIO_RAW.lock().ok()?.take(dst, last)
}
pub fn get_clipboards(client: bool) -> Option<MultiClipboards> {
if client {
CLIPBOARDS_CLIENT.lock().ok()?.take()
} else {
CLIPBOARDS_HOST.lock().ok()?.take()
}
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_onVideoFrameUpdate(
env: JNIEnv,
_class: JClass,
buffer: JObject,
) {
let jb = JByteBuffer::from(buffer);
if let Ok(data) = env.get_direct_buffer_address(&jb) {
if let Ok(len) = env.get_direct_buffer_capacity(&jb) {
VIDEO_RAW.lock().unwrap().update(data, len);
}
}
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_onAudioFrameUpdate(
env: JNIEnv,
_class: JClass,
buffer: JObject,
) {
let jb = JByteBuffer::from(buffer);
if let Ok(data) = env.get_direct_buffer_address(&jb) {
if let Ok(len) = env.get_direct_buffer_capacity(&jb) {
AUDIO_RAW.lock().unwrap().update(data, len);
}
}
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_onClipboardUpdate(
env: JNIEnv,
_class: JClass,
buffer: JByteBuffer,
) {
if let Ok(data) = env.get_direct_buffer_address(&buffer) {
if let Ok(len) = env.get_direct_buffer_capacity(&buffer) {
let data = unsafe { std::slice::from_raw_parts(data, len) };
if let Ok(clips) = MultiClipboards::parse_from_bytes(&data[1..]) {
let is_client = data[0] == 1;
if is_client {
*CLIPBOARDS_CLIENT.lock().unwrap() = Some(clips);
} else {
*CLIPBOARDS_HOST.lock().unwrap() = Some(clips);
}
}
}
}
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_setFrameRawEnable(
env: JNIEnv,
_class: JClass,
name: JString,
value: jboolean,
) {
let mut env = env;
if let Ok(name) = env.get_string(&name) {
let name: String = name.into();
let value = value.eq(&1);
if name.eq("video") {
VIDEO_RAW.lock().unwrap().set_enable(value);
} else if name.eq("audio") {
AUDIO_RAW.lock().unwrap().set_enable(value);
}
};
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_init(env: JNIEnv, _class: JClass, ctx: JObject) {
log::debug!("MainService init from java");
if let Ok(jvm) = env.get_java_vm() {
let java_vm = jvm.get_java_vm_pointer() as *mut c_void;
let mut jvm_lock = JVM.write().unwrap();
if jvm_lock.is_none() {
*jvm_lock = Some(jvm);
}
drop(jvm_lock);
if let Ok(context) = env.new_global_ref(ctx) {
let context_jobject = context.as_obj().as_raw() as *mut c_void;
*MAIN_SERVICE_CTX.write().unwrap() = Some(context);
init_ndk_context(java_vm, context_jobject);
}
}
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_setClipboardManager(
env: JNIEnv,
_class: JClass,
clipboard_manager: JObject,
) {
log::debug!("ClipboardManager init from java");
if let Ok(jvm) = env.get_java_vm() {
let java_vm = jvm.get_java_vm_pointer() as *mut c_void;
let mut jvm_lock = JVM.write().unwrap();
if jvm_lock.is_none() {
*jvm_lock = Some(jvm);
}
drop(jvm_lock);
if let Ok(manager) = env.new_global_ref(clipboard_manager) {
*CLIPBOARD_MANAGER.write().unwrap() = Some(manager);
}
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct MediaCodecInfo {
pub name: String,
pub is_encoder: bool,
#[serde(default)]
pub hw: Option<bool>, // api 29+
pub mime_type: String,
pub surface: bool,
pub nv12: bool,
#[serde(default)]
pub low_latency: Option<bool>, // api 30+, decoder
pub min_bitrate: u32,
pub max_bitrate: u32,
pub min_width: usize,
pub max_width: usize,
pub min_height: usize,
pub max_height: usize,
}
#[derive(Debug, Deserialize, Clone)]
pub struct MediaCodecInfos {
pub version: usize,
pub w: usize, // aligned
pub h: usize, // aligned
pub codecs: Vec<MediaCodecInfo>,
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_setCodecInfo(env: JNIEnv, _class: JClass, info: JString) {
let mut env = env;
if let Ok(info) = env.get_string(&info) {
let info: String = info.into();
if let Ok(infos) = serde_json::from_str::<MediaCodecInfos>(&info) {
*MEDIA_CODEC_INFOS.write().unwrap() = Some(infos);
}
}
}
pub fn get_codec_info() -> Option<MediaCodecInfos> {
MEDIA_CODEC_INFOS.read().unwrap().as_ref().cloned()
}
pub fn clear_codec_info() {
*MEDIA_CODEC_INFOS.write().unwrap() = None;
}
// another way to fix "reference table overflow" error caused by new_string and call_main_service_pointer_input frequently calld
// is below, but here I change kind from string to int for performance
/*
env.with_local_frame(10, || {
let kind = env.new_string(kind)?;
env.call_method(
ctx,
"rustPointerInput",
"(Ljava/lang/String;III)V",
&[
JValue::Object(&JObject::from(kind)),
JValue::Int(mask),
JValue::Int(x),
JValue::Int(y),
],
)?;
Ok(JObject::null())
})?;
*/
pub fn call_main_service_pointer_input(kind: &str, mask: i32, x: i32, y: i32) -> JniResult<()> {
if let (Some(jvm), Some(ctx)) = (
JVM.read().unwrap().as_ref(),
MAIN_SERVICE_CTX.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread_as_daemon()?;
let kind = if kind == "touch" { 0 } else { 1 };
env.call_method(
ctx,
"rustPointerInput",
"(IIII)V",
&[
JValue::Int(kind),
JValue::Int(mask),
JValue::Int(x),
JValue::Int(y),
],
)?;
return Ok(());
} else {
return Err(JniError::ThrowFailed(-1));
}
}
pub fn call_main_service_key_event(data: &[u8]) -> JniResult<()> {
if let (Some(jvm), Some(ctx)) = (
JVM.read().unwrap().as_ref(),
MAIN_SERVICE_CTX.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread_as_daemon()?;
let data = env.byte_array_from_slice(data)?;
env.call_method(
ctx,
"rustKeyEventInput",
"([B)V",
&[JValue::Object(&JObject::from(data))],
)?;
return Ok(());
} else {
return Err(JniError::ThrowFailed(-1));
}
}
fn _call_clipboard_manager<S, T>(name: S, sig: T, args: &[JValue]) -> JniResult<()>
where
S: Into<JNIString>,
T: Into<JNIString> + AsRef<str>,
{
if let (Some(jvm), Some(cm)) = (
JVM.read().unwrap().as_ref(),
CLIPBOARD_MANAGER.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread()?;
env.call_method(cm, name, sig, args)?;
return Ok(());
} else {
return Err(JniError::ThrowFailed(-1));
}
}
pub fn call_clipboard_manager_update_clipboard(data: &[u8]) -> JniResult<()> {
if let (Some(jvm), Some(cm)) = (
JVM.read().unwrap().as_ref(),
CLIPBOARD_MANAGER.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread()?;
let data = env.byte_array_from_slice(data)?;
env.call_method(
cm,
"rustUpdateClipboard",
"([B)V",
&[JValue::Object(&JObject::from(data))],
)?;
return Ok(());
} else {
return Err(JniError::ThrowFailed(-1));
}
}
pub fn call_clipboard_manager_enable_client_clipboard(enable: bool) -> JniResult<()> {
_call_clipboard_manager(
"rustEnableClientClipboard",
"(Z)V",
&[JValue::Bool(jboolean::from(enable))],
)
}
pub fn call_main_service_get_by_name(name: &str) -> JniResult<String> {
if let (Some(jvm), Some(ctx)) = (
JVM.read().unwrap().as_ref(),
MAIN_SERVICE_CTX.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread_as_daemon()?;
let res = env.with_local_frame(10, |env| -> JniResult<String> {
let name = env.new_string(name)?;
let res = env
.call_method(
ctx,
"rustGetByName",
"(Ljava/lang/String;)Ljava/lang/String;",
&[JValue::Object(&JObject::from(name))],
)?
.l()?;
let res = JString::from(res);
let res = env.get_string(&res)?;
let res = res.to_string_lossy().to_string();
Ok(res)
})?;
Ok(res)
} else {
return Err(JniError::ThrowFailed(-1));
}
}
pub fn call_main_service_set_by_name(
name: &str,
arg1: Option<&str>,
arg2: Option<&str>,
) -> JniResult<()> {
if let (Some(jvm), Some(ctx)) = (
JVM.read().unwrap().as_ref(),
MAIN_SERVICE_CTX.read().unwrap().as_ref(),
) {
let mut env = jvm.attach_current_thread_as_daemon()?;
env.with_local_frame(10, |env| -> JniResult<()> {
let name = env.new_string(name)?;
let arg1 = env.new_string(arg1.unwrap_or(""))?;
let arg2 = env.new_string(arg2.unwrap_or(""))?;
env.call_method(
ctx,
"rustSetByName",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V",
&[
JValue::Object(&JObject::from(name)),
JValue::Object(&JObject::from(arg1)),
JValue::Object(&JObject::from(arg2)),
],
)?;
Ok(())
})?;
return Ok(());
} else {
return Err(JniError::ThrowFailed(-1));
}
}
// Difference between MainService, MainActivity, JNI_OnLoad:
// jvm is the same, ctx is differen and ctx of JNI_OnLoad is null.
// cpal: all three works
// Service(GetByName, ...): only ctx from MainService works, so use 2 init context functions
// On app start: JNI_OnLoad or MainActivity init context
// On service start first time: MainService replace the context
fn init_ndk_context(java_vm: *mut c_void, context_jobject: *mut c_void) {
let mut lock = NDK_CONTEXT_INITED.lock().unwrap();
if *lock {
unsafe {
ndk_context::release_android_context();
}
*lock = false;
}
unsafe {
ndk_context::initialize_android_context(java_vm, context_jobject);
#[cfg(feature = "hwcodec")]
hwcodec::android::ffmpeg_set_java_vm(java_vm);
}
*lock = true;
}
fn try_init_rustls_platform_verifier(env: &mut JNIEnv, context_jobject: *mut c_void) {
use hbb_common::config::ANDROID_RUSTLS_PLATFORM_VERIFIER_INITIALIZED as INITIALIZED;
use std::sync::atomic::Ordering;
let initialized = INITIALIZED.load(Ordering::Relaxed);
if !initialized {
let ctx_for_rustls = unsafe { JObject::from_raw(context_jobject as jni::sys::jobject) };
if let Err(e) =
hbb_common::rustls_platform_verifier::android::init_hosted(env, ctx_for_rustls)
{
log::error!("Failed to initialize rustls-platform-verifier: {:?}", e);
} else {
INITIALIZED.store(true, Ordering::Relaxed);
log::info!("rustls-platform-verifier initialized successfully");
}
}
}
// https://cjycode.com/flutter_rust_bridge/guides/how-to/ndk-init
#[no_mangle]
pub extern "C" fn JNI_OnLoad(vm: jni::JavaVM, res: *mut std::os::raw::c_void) -> jni::sys::jint {
if let Ok(env) = vm.get_env() {
let vm = vm.get_java_vm_pointer() as *mut std::os::raw::c_void;
init_ndk_context(vm, res);
}
jni::JNIVersion::V6.into()
}
#[no_mangle]
pub extern "system" fn Java_ffi_FFI_onAppStart(mut env: JNIEnv, _class: JClass, ctx: JObject) {
if ctx.is_null() {
log::error!("application context is null");
return;
}
if APPLICATION_CONTEXT.read().unwrap().is_some() {
log::info!("application context already initialized");
return;
}
if let Ok(jvm) = env.get_java_vm() {
if let Ok(context) = env.new_global_ref(ctx) {
let java_vm = jvm.get_java_vm_pointer() as *mut c_void;
let context_jobject = context.as_obj().as_raw() as *mut c_void;
*APPLICATION_CONTEXT.write().unwrap() = Some(context);
try_init_rustls_platform_verifier(&mut env, context_jobject);
}
}
}
+3
View File
@@ -0,0 +1,3 @@
pub mod ffi;
pub use ffi::*;
+10
View File
@@ -0,0 +1,10 @@
#include <aom/aom.h>
#include <aom/aom_image.h>
#include <aom/aom_integer.h>
#include <aom/aom_codec.h>
#include <aom/aom_external_partition.h>
#include <aom/aom_frame_buffer.h>
#include <aom/aom_encoder.h>
#include <aom/aom_decoder.h>
#include <aom/aomcx.h>
#include <aom/aomdx.h>
+9
View File
@@ -0,0 +1,9 @@
#include <vpx/vp8.h>
#include <vpx/vp8cx.h>
#include <vpx/vp8dx.h>
#include <vpx/vpx_codec.h>
#include <vpx/vpx_decoder.h>
#include <vpx/vpx_encoder.h>
#include <vpx/vpx_frame_buffer.h>
#include <vpx/vpx_image.h>
#include <vpx/vpx_integer.h>
+6
View File
@@ -0,0 +1,6 @@
#include <libyuv/convert.h>
#include <libyuv/convert_argb.h>
#include <libyuv/convert_from.h>
#include <libyuv/convert_from_argb.h>
#include <libyuv/rotate.h>
#include <libyuv/rotate_argb.h>

Some files were not shown because too many files have changed in this diff Show More