summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLinnnus <[email protected]>2024-10-01 22:29:21 +0200
committerLinnnus <[email protected]>2024-10-01 22:29:21 +0200
commitf5c119fdaf444fc0b1b59c1e07fd32a5f6ddf069 (patch)
treefd63a803e3ffa3b9e1c05549ae29eeace7adbcf0 /src
Initial commit
Diffstat (limited to 'src')
-rw-r--r--src/config.rs251
-rw-r--r--src/main.rs97
-rw-r--r--src/service.rs157
-rw-r--r--src/systemd_socket.rs473
4 files changed, 978 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..e70043e
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,251 @@
+use std::path::{Path, PathBuf};
+use std::fs::{self, File};
+use std::io;
+use std::fmt::{self, Display};
+use serde::Deserialize;
+
+/// All the application configuration is stored in this structure.
+#[derive(PartialEq, Clone, Debug)]
+pub struct Config {
+ /// The secret string shared with GitHub that is used to verify signed requests.
+ pub secret: String,
+
+ /// Event-command pairs. Each element of this array should be matched (and optionally executed)
+ /// against the commands in gaide.
+ pub commands: Vec<Command>,
+}
+
+impl Config {
+ pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
+ let raw_config = RawConfig::from_path(path)?;
+ let secret = fs::read_to_string(raw_config.secret_file)?;
+ Ok(Config {
+ secret,
+ commands: raw_config.commands,
+ })
+ }
+}
+
+/// This struct reflects the actual JSON on disk. It is further processed before being returned to
+/// the rest of the application.
+#[derive(Deserialize, Clone, Debug, PartialEq)]
+struct RawConfig {
+ /// Path to file containing the secret that was shared with GitHub.
+ secret_file: PathBuf,
+
+ /// Event-command pairs.
+ commands: Vec<Command>,
+}
+
+/// Represents an event-command pair. The command is run whenever the given event is received from
+/// GitHub's API.
+#[derive(Deserialize, Clone, Debug, PartialEq)]
+pub struct Command {
+ /// The name of an event from the GitHub API. A full list of events can be found in [GitHub's
+ /// documenation][gh-events].
+ ///
+ /// [gh-events]: https://docs.github.com/en/webhooks/webhook-events-and-payloads
+ pub event: String,
+
+ /// Path to the program to be executed when [`event`](event) occurs.
+ pub command: String,
+
+ /// Additional arguments to bass to [`command`](command).
+ #[serde(default)]
+ pub args: Vec<String>,
+}
+
+/*
+/// Serde helper which disallows empty strings for [`PathBuf`s](std::path::PathBuf). Based on [this
+/// StackOverflow post][so].
+///
+/// [so]: https://stackoverflow.com/a/46755370
+fn string_as_nonempty_pathbuf<'de, D>(deserializer: D) -> Result<PathBuf, D::Error>
+where
+ D: Deserializer<'de>
+{
+ let raw: &str = Deserialize::deserialize(deserializer)?;
+ if raw.is_empty() {
+ Err(de::Error::custom("path cannot be empty"))
+ } else {
+ Ok(PathBuf::from(raw))
+ }
+}
+*/
+
+/// Errors that can occur when reading configuration.
+#[derive(Debug)]
+pub enum ConfigError {
+ /// An IO error occured, such as failing to read the file.
+ Io(io::Error),
+ /// Decoding the file failed, e.g. if JSON is missing comma.
+ SerdeError(serde_json::Error),
+}
+
+impl From<io::Error> for ConfigError {
+ fn from(e: io::Error) -> ConfigError {
+ ConfigError::Io(e)
+ }
+}
+
+impl From<serde_json::Error> for ConfigError {
+ fn from(e: serde_json::Error) -> ConfigError {
+ use serde_json::error::Category;
+ match e.classify() {
+ Category::Io => ConfigError::Io(e.into()),
+ _ => ConfigError::SerdeError(e),
+ }
+ }
+}
+
+impl Display for ConfigError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ match self {
+ ConfigError::Io(e) => write!(f, "io error: {}", e),
+ ConfigError::SerdeError(e) => write!(f, "decoding error: {}", e),
+ }
+ }
+}
+
+impl RawConfig {
+ pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
+ let file = File::open(path.as_ref())?;
+ let config: Self = serde_json::from_reader(file)?;
+ config.validate()?;
+ Ok(config)
+ }
+
+ #[allow(dead_code)] // Useful for tests.
+ pub(self) fn from_str(s: &str) -> Result<Self, ConfigError> {
+ let config: Self = serde_json::from_str(s)?;
+ config.validate()?;
+ Ok(config)
+ }
+
+ fn validate(&self) -> Result<(), ConfigError> {
+ if self.secret_file.is_relative() {
+ eprintln!("warning: configuration key `.secret_file` is relative path. This will be resolved relative to server's CWD at runtime which is most likely not what you want.");
+ // " <- Fix shitty Vim syntax highlighting
+ }
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{Config, Command, RawConfig, ConfigError};
+ use std::path::Path;
+
+ macro_rules! assert_matches {
+ ( $e:expr , $pat:pat ) => {
+ assert_matches!($e, $pat => ())
+ };
+ ( $e:expr , $pat:pat => $c:expr ) => {
+ match $e {
+ $pat => $c,
+ ref e => panic!("assertion failed: `{:?}` does not match `{}`",
+ e, stringify!($pat))
+ }
+ };
+ }
+
+ macro_rules! assert_contains {
+ ( $a:expr , $b:expr ) => {
+ let a_string: String = $a.to_string();
+ let b_string: String = $b.to_string();
+
+ if !a_string.contains(&b_string) {
+ panic!("assertion failed: expected {:?} to contain {:?}", a_string, b_string)
+ }
+ };
+ }
+
+ #[test]
+ fn load_valid_raw_config() {
+ let config_json = r#"
+ {
+ "secret_file": "/path/to/secret.txt",
+
+ "commands": [
+ {
+ "event": "ping",
+ "command": "/usr/bin/handle-ping",
+ "args": []
+ }
+ ]
+ }
+ "#;
+ let parsed_config = RawConfig::from_str(config_json).expect("valid config");
+ let expected_config = RawConfig {
+ secret_file: Path::new("/path/to/secret.txt").to_path_buf(),
+ commands: vec![
+ Command {
+ event: "ping".to_string(),
+ command: "/usr/bin/handle-ping".to_string(),
+ args: vec![],
+ },
+ ],
+ };
+ assert_eq!(parsed_config, expected_config);
+ }
+
+ #[test]
+ fn args_are_optional() {
+ let command_json = r#"
+ {
+ "event": "ping",
+ "command": "/usr/bin/handle-ping"
+ }
+ "#;
+ let parsed_command: Command = serde_json::from_str(command_json)
+ .expect("valid configuration");
+ let expected_command = Command {
+ event: "ping".to_string(),
+ command: "/usr/bin/handle-ping".to_string(),
+ args: vec![],
+ };
+ assert_eq!(expected_command, parsed_command);
+ }
+
+ #[test]
+ fn invalid_json_gives_error() {
+ // This JSON has a trailing comma, which isn't allowed.
+ let config_json = r#"
+ {
+ "secret_file": "blah",
+ "commands": [],
+ }
+ "#;
+ let result = RawConfig::from_str(config_json);
+ let err = assert_matches!(result, Err(ConfigError::SerdeError(e)) => e);
+ assert_eq!(err.line(), 5);
+ assert_eq!(err.column(), 13);
+ assert!(err.is_syntax());
+ }
+
+ #[test]
+ fn read_valid_config() {
+ let parse_result = Config::from_path("examples/config.json");
+ let parsed_config = assert_matches!(parse_result, Ok(c @ Config { .. }) => c);
+ let expected_config = Config {
+ secret: "mysecret".to_string(),
+ commands: vec![
+ Command {
+ event: "ping".to_string(),
+ command: "/bin/echo".to_string(),
+ args: vec![
+ "Got ping event!!".to_string()
+ ],
+ },
+ Command {
+ event: "push".to_string(),
+ command: "/bin/echo".to_string(),
+ args: vec![
+ "Got push event!!".to_string()
+ ],
+ },
+ ],
+ };
+ assert_eq!(parsed_config, expected_config);
+ }
+}
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..c09aa46
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,97 @@
+// The systemd_socket module contains a lot of dead code which is only used in tests, but which I
+// would like to keep up to date in case I need the module for another project.
+#[allow(dead_code)]
+
+mod systemd_socket;
+mod service;
+mod config;
+
+use hyper::Request;
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper_util::rt::TokioIo;
+
+use std::os::unix::net::UnixListener as StdUnixListener;
+use tokio::net::UnixListener as TokioUnixListener;
+use std::io;
+use std::process;
+use std::path::Path;
+use std::env;
+
+fn load_config() -> config::Config {
+ let args = env::args().collect::<Vec<_>>();
+ if args.len() != 2 {
+ eprintln!("Too {} command line arguments", if args.len() < 2 { "few" } else { "many" });
+ eprintln!("Usage: {} <path/to/config.json>", args[0]);
+ process::exit(1);
+ }
+
+ let config_path = Path::new(&args[1]);
+ match config::Config::from_path(config_path) {
+ Ok(config) => config,
+ Err(e) => {
+ eprintln!("Error reading configuration: {}", e);
+ process::exit(1);
+ },
+ }
+}
+
+fn get_listener_from_systemd() -> io::Result<TokioUnixListener> {
+ let mut fds = systemd_socket::listen_fds(true).unwrap_or(vec![]);
+ if fds.len() != 1 {
+ eprintln!("Too {} sockets passed from systemd", if fds.len() < 1 { "few" } else { "many" });
+ eprintln!("This tool only works with systemd socket activation.");
+ process::exit(1);
+ }
+ let fd = fds.remove(0);
+
+
+ #[cfg(not(target_vendor = "apple"))] // See note in `is_socket_unix`.
+ {
+ use nix::sys::socket::SockType;
+
+ if !systemd_socket::is_socket_unix(&fd, Some(SockType::Stream), Some(true), None)
+ .unwrap_or(false)
+ {
+ eprintln!("The socket from systemd is not a streaming UNIX socket");
+ process::exit(1);
+ }
+ }
+
+ let std_listener = StdUnixListener::from(fd);
+ std_listener.set_nonblocking(true)?; // Required by tokio::net::UnixListener::from_std().
+
+ let tokio_listener = TokioUnixListener::from_std(std_listener)?;
+ Ok(tokio_listener)
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ let config = load_config();
+
+ let listener = get_listener_from_systemd()?;
+
+ // We start a loop to continuously accept incoming connections
+ loop {
+ let (stream, _) = listener.accept().await.expect("accepting new connection");
+ let io = TokioIo::new(stream);
+ let cfg = config.clone();
+
+ // Spawn a tokio task to serve multiple connections concurrently.
+ tokio::task::spawn(async move {
+ let service = service_fn(|req: Request<hyper::body::Incoming>| {
+ service::router(req, &cfg)
+ });
+
+ let conn = http1::Builder::new()
+ // On OSX, disabling keep alive prevents serve_connection from
+ // blocking and later returning an `Err` derived from `ENOTCONN`.
+ .keep_alive(false)
+ .serve_connection(io, service);
+
+ if let Err(err) = conn.await {
+ eprintln!("Error serving connection: {:?}", err);
+ }
+ });
+ }
+}
diff --git a/src/service.rs b/src/service.rs
new file mode 100644
index 0000000..a261bdf
--- /dev/null
+++ b/src/service.rs
@@ -0,0 +1,157 @@
+//! This module contains the service that is being served with Hyper (our HTTP server library). The
+//! functions in here are responsible for taking requests from the GitHub API and producing
+//! responses.
+
+use crate::config::{self, Config};
+
+use http_body_util::{combinators::BoxBody, BodyExt, Full, Empty};
+use hyper::body::{Body, Bytes};
+use hyper::header::{HeaderMap, HeaderValue};
+use hyper::{Request, Response, Method, StatusCode};
+
+use hmac::{Hmac, Mac};
+use sha2::Sha256;
+use std::num::ParseIntError;
+
+use tokio::process::Command;
+use tokio::io::AsyncWriteExt;
+use std::io;
+use std::process::{ExitStatus, Stdio};
+
+/// Alias for hasher implementing HMAC-SHA256.
+type HmacSha256 = Hmac<Sha256>;
+
+/// Dispatches HTTP requests to different handlers, returning their result.
+pub async fn router(
+ req: Request<hyper::body::Incoming>,
+ config: &Config,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
+ match (req.method(), req.uri().path()) {
+ (&Method::POST, "/") => handle_webhook_post(req, config).await,
+ _ => Ok(empty_res(StatusCode::NOT_FOUND)),
+ }
+}
+
+async fn handle_webhook_post(
+ req: Request<hyper::body::Incoming>,
+ config: &Config,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
+ let (head, body) = req.into_parts();
+
+ // Extract the event type early on. This allows us to exit before doing expensive signature
+ // checking, if the header is missing or invalid ASCII.
+ let event = match head.headers.get("X-GitHub-event").map(HeaderValue::to_str) {
+ Some(Ok(event)) => event,
+ Some(Err(_)) => return Ok(full_res("Invalid ASCII in header: X-GitHub-Event", StatusCode::BAD_REQUEST)),
+ None => return Ok(full_res("Missing header: X-GitHub-Event", StatusCode::BAD_REQUEST)),
+ };
+
+ // Read entire body into `Bytes`. We have to set an upper limit to protect the server from
+ // massive allocations.
+ let upper = body.size_hint().upper().unwrap_or(u64::MAX);
+ if upper > 1024 * 64 {
+ eprintln!("Rejecting request because payload is too large.");
+ return Ok(full_res("Body too big", StatusCode::PAYLOAD_TOO_LARGE));
+ }
+ let body = body.collect().await?.to_bytes();
+
+ // Now that we have read the entire body, we should validate the signature before proceeding.
+ if !validate_request(&config.secret, &head.headers, &body) {
+ eprintln!("Rejecting request becuase signature is missing or invaldi");
+ return Ok(full_res("Missing or invalid signature", StatusCode::BAD_REQUEST));
+ }
+
+ for command in &config.commands {
+ if command.event == event {
+ let command_clone = command.clone();
+ let body_clone = body.clone();
+ tokio::spawn(async move {
+ match run_command(&command_clone, body_clone.as_ref()).await {
+ Ok(s) => match s.code() {
+ Some(code) => println!("Command finished with exit code {}: {:?}", code, command_clone),
+ None => println!("Command finished without exit code: {:?}", command_clone),
+ },
+ Err(e) => eprintln!("Failed to spawn command: {:?}\nerror: {}", command_clone, e),
+ }
+ });
+ }
+ }
+
+ Ok(empty_res(StatusCode::NO_CONTENT))
+}
+
+async fn run_command(command: &config::Command, body: &[u8]) -> io::Result<ExitStatus> {
+ let mut child = Command::new(&command.command)
+ .stdin(Stdio::piped()) // We will feed the event data through stdin.
+ .stdout(Stdio::inherit())
+ .stderr(Stdio::inherit())
+ .args(&command.args)
+ .spawn()?;
+
+ // Feed data through stdin. Sure hope whatever a "deadlock" is doesn't happen here.
+ let mut child_stdin = child.stdin.take().expect("child has stdin");
+ child_stdin.write_all(body).await?;
+ drop(child_stdin);
+
+ Ok(child.wait().await?)
+}
+
+/// Utility to create an empty response.
+fn empty_res(status: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
+ let body = Empty::<Bytes>::new()
+ .map_err(|never| match never {})
+ .boxed();
+
+ let mut response = Response::new(body);
+ *response.status_mut() = status;
+ response
+}
+
+/// Utility to create a full (i.e. with content) response.
+fn full_res<T: Into<Bytes>>(
+ chunk: T,
+ status: StatusCode,
+) -> Response<BoxBody<Bytes, hyper::Error>> {
+ let body = Full::new(chunk.into())
+ .map_err(|never| match never {})
+ .boxed();
+
+ let mut response = Response::new(body);
+ *response.status_mut() = status;
+ response
+}
+
+/// Decodes a string slice into a string of bytes.
+///
+/// Implementation taken from [this stackoverflow post](https://stackoverflow.com/a/52992629).
+fn decode_hex(s: &str) -> Result<Vec<u8>, ParseIntError> {
+ (0..s.len())
+ .step_by(2)
+ .map(|i| u8::from_str_radix(&s[i..i + 2], 16))
+ .collect()
+}
+
+/// Validates the signature that GitHub attaches to events.
+///
+///
+fn validate_request(secret: &String, headers: &HeaderMap<HeaderValue>, body: &Bytes) -> bool {
+ // To verify the authenticity of the event, GitHub attaches a signature of the payload to
+ // every request. We extract the header. The header value will look something like this:
+ //
+ // x-hub-signature-256: sha256=6803d2a3e495fc4bd286d428ea4b794476a1ff1b72bbea4dfafd2477d5d89188
+ let maybe_signature = headers
+ .get("x-hub-signature-256")
+ .and_then(|hv| hv.to_str().ok()) // HeaderValue => &str
+ .and_then(|s| s.strip_prefix("sha256=")) // sha256=2843i4aklds... => 2843i4aklds...
+ .and_then(|s| decode_hex(s).ok()); // &str -> vec<u8>
+ let signature = match maybe_signature {
+ Some(s) => s,
+ None => return false, // Missing or invalid signature
+ };
+
+ // Now we independantly calculate a signature of the payload we just read, using the secret. If
+ // Github computed the signature with the same secret, we should be all good.
+ let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).unwrap();
+ mac.update(&body);
+ mac.verify_slice(&signature).is_ok()
+}
diff --git a/src/systemd_socket.rs b/src/systemd_socket.rs
new file mode 100644
index 0000000..043fd8e
--- /dev/null
+++ b/src/systemd_socket.rs
@@ -0,0 +1,473 @@
+//! `systemd_socket` implements the daemon side of the socket activation. The interface is similar
+//! to the one provided by the systemd/sd-daemon library, but adjusted for easier usage in rust. It
+//! relies on `nix` for all low-level operations. All checks are ported from the systemd code.
+//!
+//! Enums required for socket type (`SockType`) and address family (`AddressFamily`) are reexported
+//! from nix.
+//!
+//! The library is based on [rust-systemd](https://github.com/jmesmon/rust-systemd) by Cody P
+//! Schafer, but it does not require any extra libraries and works on rust stable.
+
+// I'm hoping to bring this module with me to other packages, so let's just allow all the functions
+// which _are_ useful, just not for this project. That's why there are a lot of `allow(dead_code)`
+// in this module.
+
+use nix::fcntl;
+use nix::libc;
+use nix::sys::socket::{self, SockaddrLike};
+use nix::sys::stat;
+use nix::unistd::Pid;
+use std::collections::HashMap;
+use std::convert::From;
+use std::env;
+use std::error::Error as StdError;
+use std::fmt;
+use std::num::ParseIntError;
+use std::os::unix::io::{OwnedFd, RawFd};
+use std::os::fd::{AsFd, AsRawFd, FromRawFd};
+use std::path;
+
+pub use nix::sys::socket::SockType;
+pub use nix::sys::socket::AddressFamily;
+
+const VAR_FDS: &'static str = "LISTEN_FDS";
+const VAR_NAMES: &'static str = "LISTEN_FDNAMES";
+const VAR_PID: &'static str = "LISTEN_PID";
+
+#[derive(Debug, PartialEq)]
+pub enum Error {
+ Var(env::VarError),
+ Parse(ParseIntError),
+ DifferentProcess,
+ InvalidVariableValue,
+ Nix(nix::Error),
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", self)
+ }
+}
+
+impl StdError for Error {
+ fn description(&self) -> &str {
+ match self {
+ &Error::InvalidVariableValue => "Environment variable could not be parsed",
+ &Error::DifferentProcess =>
+ "Environment variables are meant for a different process (pid mismatch)",
+ &Error::Var(_) => "Required environment variable missing or unreadable",
+ &Error::Parse(_) => "Could not parse number in 'LISTEN_FDS'",
+ &Error::Nix(_) => "Calling system function on socket failed",
+ }
+ }
+
+ fn cause(&self) -> Option<&dyn StdError> {
+ match self {
+ &Error::Var(ref e) => Some(e),
+ &Error::Parse(ref e) => Some(e),
+ &Error::Nix(ref e) => Some(e),
+ _ => None,
+ }
+ }
+}
+
+impl From<env::VarError> for Error {
+ fn from(e: env::VarError) -> Error {
+ Error::Var(e)
+ }
+}
+
+impl From<ParseIntError> for Error {
+ fn from(e: ParseIntError) -> Error {
+ Error::Parse(e)
+ }
+}
+
+impl From<nix::Error> for Error {
+ fn from(e: nix::Error) -> Error {
+ Error::Nix(e)
+ }
+}
+
+/// Encapsulates the possible failure modes of local functions.
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Number of the first passed file descriptor
+const LISTEN_FDS_START: RawFd = 3;
+
+fn unset_all_env() {
+ env::remove_var(VAR_PID);
+ env::remove_var(VAR_FDS);
+ env::remove_var(VAR_NAMES);
+}
+
+/// Returns the file descriptors passed in by init process. Removes the `$LISTEN_FDS` and
+/// `$LISTEN_PID` variables from the environment if `unset_environment` is `true`.
+pub fn listen_fds(unset_environment: bool) -> Result<Vec<OwnedFd>> {
+ let pid_str = env::var(VAR_PID)?;
+ let pid_raw: libc::pid_t = pid_str.parse()?;
+ let pid = Pid::from_raw(pid_raw);
+
+ if pid != nix::unistd::getpid() {
+ return Err(Error::DifferentProcess);
+ }
+
+ let fds_str = env::var(VAR_FDS)?;
+ let fds: libc::c_int = fds_str.parse()?;
+
+ if fds < 0 {
+ return Err(Error::InvalidVariableValue);
+ }
+
+ for fd in LISTEN_FDS_START..(LISTEN_FDS_START+fds) {
+ fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFD(fcntl::FdFlag::FD_CLOEXEC))?;
+ }
+
+ if unset_environment {
+ unset_all_env();
+ }
+ let fd_vec: Vec<_> = (LISTEN_FDS_START .. (LISTEN_FDS_START+fds))
+ .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) })
+ .collect();
+ Ok(fd_vec)
+}
+
+/// Returns file descriptors with names. Removes the `$LISTEN_FDS` and `$LISTEN_PID` variables from
+/// the environment if `unset_environment` is `true`.
+#[allow(unused)]
+pub fn listen_fds_with_names(unset_environment: bool) -> Result<HashMap<String, OwnedFd>> {
+ let names_str = env::var(VAR_NAMES)?;
+ let names: Vec<&str> = names_str.split(':').collect();
+
+ let fds: Vec<OwnedFd> = listen_fds(unset_environment)?;
+ if fds.len() != names.len() {
+ return Err(Error::InvalidVariableValue);
+ }
+
+ let mut map = HashMap::new();
+ for (name, fd) in names.into_iter().zip(fds) {
+ map.insert(name.to_string(), fd);
+ }
+ Ok(map)
+}
+
+/// Identifies whether the passed file descriptor is a FIFO. If a path is
+/// supplied, the file descriptor must also match the path.
+#[allow(unused)]
+pub fn is_fifo<T: AsRawFd>(fd: T, path: Option<&str>) -> Result<bool> {
+ let fs = stat::fstat(fd.as_raw_fd())?;
+ let mode = stat::SFlag::from_bits_truncate(fs.st_mode);
+ if !mode.contains(stat::SFlag::S_IFIFO) {
+ return Ok(false);
+ }
+ if let Some(path_str) = path {
+ let path_stat = match stat::stat(path::Path::new(path_str)) {
+ Ok(x) => x,
+ Err(_) => {return Ok(false)},
+ };
+ return Ok(path_stat.st_dev == fs.st_dev && path_stat.st_ino == fs.st_ino);
+ }
+ Ok(true)
+}
+
+/// Identifies whether the passed file descriptor is a special character device.
+/// If a path is supplied, the file descriptor must also match the path.
+#[allow(unused)]
+pub fn is_special<T: AsRawFd>(fd: T, path: Option<&str>) -> Result<bool> {
+ let fs = stat::fstat(fd.as_raw_fd())?;
+ let mode = stat::SFlag::from_bits_truncate(fs.st_mode);
+ if !mode.contains(stat::SFlag::S_IFREG) && !mode.contains(stat::SFlag::S_IFCHR) {
+ // path not comparable
+ return Ok(true);
+ }
+
+ if let Some(path_str) = path {
+ let path_stat = match stat::stat(path::Path::new(path_str)) {
+ Ok(x) => x,
+ Err(_) => {return Ok(false)},
+ };
+
+ let path_mode = stat::SFlag::from_bits_truncate(path_stat.st_mode);
+ if (mode & path_mode).contains(stat::SFlag::S_IFREG) {
+ return Ok(path_stat.st_dev == fs.st_dev && path_stat.st_ino == fs.st_ino);
+ }
+
+ if (mode & path_mode).contains(stat::SFlag::S_IFCHR) {
+ return Ok(path_stat.st_rdev == fs.st_rdev);
+ }
+
+ return Ok(false);
+ }
+
+ Ok(true)
+}
+
+/// Do checks common to all socket verification functions. (type, listening state)
+#[allow(unused)]
+fn is_socket_internal<T: AsFd>(fd: &T, socktype: Option<SockType>,
+ listening: Option<bool>) -> Result<bool> {
+ /*if fd < 0 {
+ return Err(Error::InvalidFdValue);
+ }*/
+
+ let fs = stat::fstat(fd.as_fd().as_raw_fd())?;
+ let mode = stat::SFlag::from_bits_truncate(fs.st_mode);
+ if !mode.contains(stat::SFlag::S_IFSOCK) {
+ return Ok(false);
+ }
+ if let Some(val) = socktype {
+ let typ: SockType = socket::getsockopt(&fd, socket::sockopt::SockType)?;
+ if typ != val {
+ return Ok(false);
+ }
+ }
+
+ if let Some(val) = listening {
+ // This is broken on MacOS, as according to [getsockopt(2)] and [this stackoverflow
+ // anser][so], `SO_ACCEPTCONN` is not
+ // supported at the `SOL_SOCKET` level. I assume this also applies to other platforms using
+ // the Darwin kernel, i.e. all Apple's platfroms.
+ //
+ // [getsockopt(2)]: https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/getsockopt.2.html
+ // [so]: https://stackoverflow.com/a/75943802
+ if cfg!(target_vendor = "apple") {
+ todo!("Getting listening state is not implemented on Apple's Darwin kernel");
+ }
+
+ let acc = socket::getsockopt(&fd, socket::sockopt::AcceptConn)?;
+ if acc != val {
+ return Ok(false);
+ }
+ }
+
+ Ok(true)
+}
+
+/// Identifies whether the passed file descriptor is a socket. If family,
+/// type, and listening state are supplied, they must match as well.
+#[allow(unused)]
+pub fn is_socket<T: AsFd>(fd: &T, family: Option<AddressFamily>, socktype: Option<SockType>,
+ listening: Option<bool>) -> Result<bool> {
+ if !is_socket_internal(fd, socktype, listening)? {
+ return Ok(false);
+ }
+
+ if let Some(f) = family {
+ let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?;
+ let sock_family = sock_addr.family().unwrap();
+ if sock_family != f {
+ return Ok(false);
+ }
+ }
+
+ Ok(true)
+}
+
+/// Identifies whether the passed file descriptor is an Internet socket. If family, type, listening
+/// state, and/or port are supplied, they must match as well.
+pub fn is_socket_inet<T: AsFd>(fd: &T, family: Option<AddressFamily>, socktype: Option<SockType>,
+ listening: Option<bool>, port: Option<u16>) -> Result<bool> {
+ if !is_socket_internal(fd, socktype, listening)? {
+ return Ok(false);
+ }
+
+ let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?;
+ let sock_family = sock_addr.family().unwrap();
+ if sock_family != AddressFamily::Inet && sock_family != AddressFamily::Inet6 {
+ return Ok(false);
+ }
+
+ if let Some(val) = family {
+ if sock_family != val {
+ return Ok(false);
+ }
+ }
+
+ if let Some(expected_port) = port {
+ let port = match sock_family {
+ socket::AddressFamily::Inet => sock_addr.as_sockaddr_in().unwrap().port(),
+ socket::AddressFamily::Inet6 => sock_addr.as_sockaddr_in6().unwrap().port(),
+ _ => unreachable!(),
+ };
+ if port != expected_port {
+ return Ok(false);
+ }
+ }
+
+ Ok(true)
+}
+
+/// Identifies whether the passed file descriptor is an AF_UNIX socket. If type are supplied, it
+/// must match as well. Path checking is currently unsupported and will be ignored
+#[allow(unused)]
+pub fn is_socket_unix<T: AsFd>(fd: &T, socktype: Option<SockType>, listening: Option<bool>,
+ path: Option<&str>) -> Result<bool> {
+ if !is_socket_internal(fd, socktype, listening)? {
+ return Ok(false);
+ }
+
+ let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?;
+ let sock_family = sock_addr.family().unwrap();
+ if sock_family != AddressFamily::Unix {
+ return Ok(false);
+ }
+
+ if let Some(_val) = path {
+ // TODO: unsupported
+ }
+
+ Ok(true)
+}
+
+// TODO
+///// Identifies whether the passed file descriptor is a POSIX message queue. If a
+///// path is supplied, it will also verify the name.
+//pub fn is_mq(fd: RawFd, path: Option<&str>) -> Result<bool> {
+//}
+
+#[cfg(test)]
+mod tests {
+ use ::nix;
+ use ::lazy_static::lazy_static;
+ use ::std::env;
+ use ::std::os::unix::io::OwnedFd;
+ use ::std::os::fd::{AsRawFd, FromRawFd, RawFd};
+ use ::std::sync::{Mutex,MutexGuard};
+ use ::std::mem;
+
+ // Even with one -j1, cargo runs multiple tests at once. That doesn't work with environment
+ // variables, or specific socket ordering, so mutexes are required.
+ lazy_static! {
+ static ref LOCK: Mutex<()> = Mutex::new(());
+ }
+
+ fn lock_env<'a>() -> MutexGuard<'a, ()> {
+ // SAFETY: We can ignore `PoisonError`s since the ressource we are locking is just `()`.
+ // See: <https://stackoverflow.com/a/51694631>.
+ LOCK.lock().unwrap_or_else(|e| e.into_inner())
+ }
+
+ fn set_current_pid() {
+ let pid = nix::unistd::getpid();
+ env::set_var(super::VAR_PID, format!("{}", pid));
+ }
+
+ /// Create a new socket with the given `family` and `typ`e.
+ ///
+ /// This function is used by the `is_*` tests, so it returns an owned ressource (as opposed to
+ /// [`create_socket_with_fd`](self::create_socket_with_fd)).
+ fn create_socket(family: super::AddressFamily, typ: super::SockType) -> OwnedFd {
+ nix::sys::socket::socket(family, typ, nix::sys::socket::SockFlag::empty(), None).unwrap()
+ }
+
+ /// Create a new socket with the given `family` and `typ`e, asserting that it gets assigned a
+ /// specific fd.
+ ///
+ /// This function is used to simulate Systemd opening a socket for us, so the actual ressource
+ /// is ["forgotten"](std::mem::forget).
+ fn create_socket_with_fd(no: nix::libc::c_int, family: super::AddressFamily, typ: super::SockType) {
+ debug_assert!(no > 0, "Valid file descriptors are always positive");
+
+ // Allocate a socket. During normal operation, this would be done by Systemd before our
+ // program was even started.
+ let fd = create_socket(family, typ);
+ assert_eq!(fd.as_raw_fd(), no, "Expected new socket to have fd {} but got {}", no, fd.as_raw_fd());
+
+ // We don't want Rust to manage the ressource for us (YET), as this function is supposed to
+ // mimic how Systemd would file descriptors to us.
+ mem::forget(fd);
+ }
+
+ /// Returns a file descriptor for a regular file.
+ fn open_file() -> OwnedFd {
+ let path = ::std::path::Path::new("/etc/hosts");
+ let fd = nix::fcntl::open(path, nix::fcntl::OFlag::O_RDONLY, nix::sys::stat::Mode::empty()).unwrap();
+ unsafe { OwnedFd::from_raw_fd(fd) }
+ }
+
+ #[test]
+ fn listen_fds_success() {
+ let _l = lock_env();
+ set_current_pid();
+ let _fd = create_socket_with_fd(3, super::AddressFamily::Inet, super::SockType::Stream);
+ env::set_var(super::VAR_FDS, "1");
+ let fds = super::listen_fds(true).unwrap();
+ assert_eq!(fds.len(), 1);
+ assert_eq!(fds[0].as_raw_fd(), 3);
+ }
+
+ #[test]
+ fn names() {
+ let _l = lock_env();
+ set_current_pid();
+ env::set_var(super::VAR_FDS, "2");
+ env::set_var(super::VAR_NAMES, "a:b");
+ let _fd1 = create_socket_with_fd(3, super::AddressFamily::Inet, super::SockType::Stream);
+ let _fd2 = create_socket_with_fd(4, super::AddressFamily::Inet, super::SockType::Stream);
+ let fds = super::listen_fds_with_names(true).unwrap();
+ assert_eq!(fds.len(), 2);
+ assert_eq!(fds["a"].as_raw_fd(), 3);
+ assert_eq!(fds["b"].as_raw_fd(), 4);
+ }
+
+ #[test]
+ fn listen_fds_cleans() {
+ let _l = lock_env();
+ set_current_pid();
+ env::set_var(super::VAR_FDS, "0");
+ super::listen_fds(false).unwrap();
+ assert_eq!(env::var(super::VAR_FDS), Ok("0".into()));
+ super::listen_fds(true).unwrap();
+ assert_eq!(env::var(super::VAR_FDS), Err(env::VarError::NotPresent));
+ assert_eq!(env::var(super::VAR_PID), Err(env::VarError::NotPresent));
+ assert_eq!(env::var(super::VAR_NAMES), Err(env::VarError::NotPresent));
+ }
+
+ #[test]
+ fn is_socket() {
+ let _l = lock_env();
+
+ let fd = create_socket(super::AddressFamily::Inet, super::SockType::Stream);
+ assert!(super::is_socket(&fd, None, None, None).unwrap());
+ #[cfg(not(target_vendor = "apple"))]
+ assert!(super::is_socket(&fd, Some(super::AddressFamily::Inet),
+ Some(super::SockType::Stream), Some(false)).unwrap());
+ #[cfg(target_vendor = "apple")]
+ assert!(super::is_socket(&fd, Some(super::AddressFamily::Inet),
+ Some(super::SockType::Stream), None).unwrap());
+
+ let fd = open_file();
+ assert!(!super::is_socket(&fd, None, None, None).unwrap());
+ }
+
+ #[test]
+ fn is_socket_inet() {
+ let _l = lock_env();
+ let fd = create_socket(super::AddressFamily::Inet, super::SockType::Stream);
+ assert!(super::is_socket_inet(&fd, None, None, None, None).unwrap());
+ #[cfg(not(target_vendor = "apple"))]
+ assert!(super::is_socket_inet(&fd, Some(super::AddressFamily::Inet),
+ Some(super::SockType::Stream), Some(false), None).unwrap());
+ #[cfg(target_vendor = "apple")]
+ assert!(super::is_socket_inet(&fd, Some(super::AddressFamily::Inet),
+ Some(super::SockType::Stream), None, None).unwrap());
+
+ let fd = open_file();
+ assert!(!super::is_socket_inet(&fd, None, None, None, None).unwrap());
+ }
+
+ #[test]
+ fn is_socket_unix() {
+ let _l = lock_env();
+ let fd = create_socket(super::AddressFamily::Unix, super::SockType::Stream);
+ assert!(super::is_socket_unix(&fd, None, None, None).unwrap());
+ #[cfg(not(target_vendor = "apple"))]
+ assert!(super::is_socket_unix(&fd, Some(super::SockType::Stream),
+ Some(false), None).unwrap());
+ #[cfg(target_vendor = "apple")]
+ assert!(super::is_socket_unix(&fd, Some(super::SockType::Stream), None, None).unwrap());
+
+ let fd = open_file();
+ assert!(!super::is_socket_unix(&fd, None, None, None).unwrap());
+ }
+}