From f5c119fdaf444fc0b1b59c1e07fd32a5f6ddf069 Mon Sep 17 00:00:00 2001 From: Linnnus Date: Tue, 1 Oct 2024 22:29:21 +0200 Subject: Initial commit --- src/config.rs | 251 +++++++++++++++++++++++++++ src/main.rs | 97 +++++++++++ src/service.rs | 157 +++++++++++++++++ src/systemd_socket.rs | 473 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 978 insertions(+) create mode 100644 src/config.rs create mode 100644 src/main.rs create mode 100644 src/service.rs create mode 100644 src/systemd_socket.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..e70043e --- /dev/null +++ b/src/config.rs @@ -0,0 +1,251 @@ +use std::path::{Path, PathBuf}; +use std::fs::{self, File}; +use std::io; +use std::fmt::{self, Display}; +use serde::Deserialize; + +/// All the application configuration is stored in this structure. +#[derive(PartialEq, Clone, Debug)] +pub struct Config { + /// The secret string shared with GitHub that is used to verify signed requests. + pub secret: String, + + /// Event-command pairs. Each element of this array should be matched (and optionally executed) + /// against the commands in gaide. + pub commands: Vec, +} + +impl Config { + pub fn from_path>(path: P) -> Result { + let raw_config = RawConfig::from_path(path)?; + let secret = fs::read_to_string(raw_config.secret_file)?; + Ok(Config { + secret, + commands: raw_config.commands, + }) + } +} + +/// This struct reflects the actual JSON on disk. It is further processed before being returned to +/// the rest of the application. +#[derive(Deserialize, Clone, Debug, PartialEq)] +struct RawConfig { + /// Path to file containing the secret that was shared with GitHub. + secret_file: PathBuf, + + /// Event-command pairs. + commands: Vec, +} + +/// Represents an event-command pair. The command is run whenever the given event is received from +/// GitHub's API. +#[derive(Deserialize, Clone, Debug, PartialEq)] +pub struct Command { + /// The name of an event from the GitHub API. A full list of events can be found in [GitHub's + /// documenation][gh-events]. + /// + /// [gh-events]: https://docs.github.com/en/webhooks/webhook-events-and-payloads + pub event: String, + + /// Path to the program to be executed when [`event`](event) occurs. + pub command: String, + + /// Additional arguments to bass to [`command`](command). + #[serde(default)] + pub args: Vec, +} + +/* +/// Serde helper which disallows empty strings for [`PathBuf`s](std::path::PathBuf). Based on [this +/// StackOverflow post][so]. +/// +/// [so]: https://stackoverflow.com/a/46755370 +fn string_as_nonempty_pathbuf<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de> +{ + let raw: &str = Deserialize::deserialize(deserializer)?; + if raw.is_empty() { + Err(de::Error::custom("path cannot be empty")) + } else { + Ok(PathBuf::from(raw)) + } +} +*/ + +/// Errors that can occur when reading configuration. +#[derive(Debug)] +pub enum ConfigError { + /// An IO error occured, such as failing to read the file. + Io(io::Error), + /// Decoding the file failed, e.g. if JSON is missing comma. + SerdeError(serde_json::Error), +} + +impl From for ConfigError { + fn from(e: io::Error) -> ConfigError { + ConfigError::Io(e) + } +} + +impl From for ConfigError { + fn from(e: serde_json::Error) -> ConfigError { + use serde_json::error::Category; + match e.classify() { + Category::Io => ConfigError::Io(e.into()), + _ => ConfigError::SerdeError(e), + } + } +} + +impl Display for ConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + ConfigError::Io(e) => write!(f, "io error: {}", e), + ConfigError::SerdeError(e) => write!(f, "decoding error: {}", e), + } + } +} + +impl RawConfig { + pub fn from_path>(path: P) -> Result { + let file = File::open(path.as_ref())?; + let config: Self = serde_json::from_reader(file)?; + config.validate()?; + Ok(config) + } + + #[allow(dead_code)] // Useful for tests. + pub(self) fn from_str(s: &str) -> Result { + let config: Self = serde_json::from_str(s)?; + config.validate()?; + Ok(config) + } + + fn validate(&self) -> Result<(), ConfigError> { + if self.secret_file.is_relative() { + eprintln!("warning: configuration key `.secret_file` is relative path. This will be resolved relative to server's CWD at runtime which is most likely not what you want."); + // " <- Fix shitty Vim syntax highlighting + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{Config, Command, RawConfig, ConfigError}; + use std::path::Path; + + macro_rules! assert_matches { + ( $e:expr , $pat:pat ) => { + assert_matches!($e, $pat => ()) + }; + ( $e:expr , $pat:pat => $c:expr ) => { + match $e { + $pat => $c, + ref e => panic!("assertion failed: `{:?}` does not match `{}`", + e, stringify!($pat)) + } + }; + } + + macro_rules! assert_contains { + ( $a:expr , $b:expr ) => { + let a_string: String = $a.to_string(); + let b_string: String = $b.to_string(); + + if !a_string.contains(&b_string) { + panic!("assertion failed: expected {:?} to contain {:?}", a_string, b_string) + } + }; + } + + #[test] + fn load_valid_raw_config() { + let config_json = r#" + { + "secret_file": "/path/to/secret.txt", + + "commands": [ + { + "event": "ping", + "command": "/usr/bin/handle-ping", + "args": [] + } + ] + } + "#; + let parsed_config = RawConfig::from_str(config_json).expect("valid config"); + let expected_config = RawConfig { + secret_file: Path::new("/path/to/secret.txt").to_path_buf(), + commands: vec![ + Command { + event: "ping".to_string(), + command: "/usr/bin/handle-ping".to_string(), + args: vec![], + }, + ], + }; + assert_eq!(parsed_config, expected_config); + } + + #[test] + fn args_are_optional() { + let command_json = r#" + { + "event": "ping", + "command": "/usr/bin/handle-ping" + } + "#; + let parsed_command: Command = serde_json::from_str(command_json) + .expect("valid configuration"); + let expected_command = Command { + event: "ping".to_string(), + command: "/usr/bin/handle-ping".to_string(), + args: vec![], + }; + assert_eq!(expected_command, parsed_command); + } + + #[test] + fn invalid_json_gives_error() { + // This JSON has a trailing comma, which isn't allowed. + let config_json = r#" + { + "secret_file": "blah", + "commands": [], + } + "#; + let result = RawConfig::from_str(config_json); + let err = assert_matches!(result, Err(ConfigError::SerdeError(e)) => e); + assert_eq!(err.line(), 5); + assert_eq!(err.column(), 13); + assert!(err.is_syntax()); + } + + #[test] + fn read_valid_config() { + let parse_result = Config::from_path("examples/config.json"); + let parsed_config = assert_matches!(parse_result, Ok(c @ Config { .. }) => c); + let expected_config = Config { + secret: "mysecret".to_string(), + commands: vec![ + Command { + event: "ping".to_string(), + command: "/bin/echo".to_string(), + args: vec![ + "Got ping event!!".to_string() + ], + }, + Command { + event: "push".to_string(), + command: "/bin/echo".to_string(), + args: vec![ + "Got push event!!".to_string() + ], + }, + ], + }; + assert_eq!(parsed_config, expected_config); + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c09aa46 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,97 @@ +// The systemd_socket module contains a lot of dead code which is only used in tests, but which I +// would like to keep up to date in case I need the module for another project. +#[allow(dead_code)] + +mod systemd_socket; +mod service; +mod config; + +use hyper::Request; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; + +use std::os::unix::net::UnixListener as StdUnixListener; +use tokio::net::UnixListener as TokioUnixListener; +use std::io; +use std::process; +use std::path::Path; +use std::env; + +fn load_config() -> config::Config { + let args = env::args().collect::>(); + if args.len() != 2 { + eprintln!("Too {} command line arguments", if args.len() < 2 { "few" } else { "many" }); + eprintln!("Usage: {} ", args[0]); + process::exit(1); + } + + let config_path = Path::new(&args[1]); + match config::Config::from_path(config_path) { + Ok(config) => config, + Err(e) => { + eprintln!("Error reading configuration: {}", e); + process::exit(1); + }, + } +} + +fn get_listener_from_systemd() -> io::Result { + let mut fds = systemd_socket::listen_fds(true).unwrap_or(vec![]); + if fds.len() != 1 { + eprintln!("Too {} sockets passed from systemd", if fds.len() < 1 { "few" } else { "many" }); + eprintln!("This tool only works with systemd socket activation."); + process::exit(1); + } + let fd = fds.remove(0); + + + #[cfg(not(target_vendor = "apple"))] // See note in `is_socket_unix`. + { + use nix::sys::socket::SockType; + + if !systemd_socket::is_socket_unix(&fd, Some(SockType::Stream), Some(true), None) + .unwrap_or(false) + { + eprintln!("The socket from systemd is not a streaming UNIX socket"); + process::exit(1); + } + } + + let std_listener = StdUnixListener::from(fd); + std_listener.set_nonblocking(true)?; // Required by tokio::net::UnixListener::from_std(). + + let tokio_listener = TokioUnixListener::from_std(std_listener)?; + Ok(tokio_listener) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = load_config(); + + let listener = get_listener_from_systemd()?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await.expect("accepting new connection"); + let io = TokioIo::new(stream); + let cfg = config.clone(); + + // Spawn a tokio task to serve multiple connections concurrently. + tokio::task::spawn(async move { + let service = service_fn(|req: Request| { + service::router(req, &cfg) + }); + + let conn = http1::Builder::new() + // On OSX, disabling keep alive prevents serve_connection from + // blocking and later returning an `Err` derived from `ENOTCONN`. + .keep_alive(false) + .serve_connection(io, service); + + if let Err(err) = conn.await { + eprintln!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 0000000..a261bdf --- /dev/null +++ b/src/service.rs @@ -0,0 +1,157 @@ +//! This module contains the service that is being served with Hyper (our HTTP server library). The +//! functions in here are responsible for taking requests from the GitHub API and producing +//! responses. + +use crate::config::{self, Config}; + +use http_body_util::{combinators::BoxBody, BodyExt, Full, Empty}; +use hyper::body::{Body, Bytes}; +use hyper::header::{HeaderMap, HeaderValue}; +use hyper::{Request, Response, Method, StatusCode}; + +use hmac::{Hmac, Mac}; +use sha2::Sha256; +use std::num::ParseIntError; + +use tokio::process::Command; +use tokio::io::AsyncWriteExt; +use std::io; +use std::process::{ExitStatus, Stdio}; + +/// Alias for hasher implementing HMAC-SHA256. +type HmacSha256 = Hmac; + +/// Dispatches HTTP requests to different handlers, returning their result. +pub async fn router( + req: Request, + config: &Config, +) -> Result>, hyper::Error> { + match (req.method(), req.uri().path()) { + (&Method::POST, "/") => handle_webhook_post(req, config).await, + _ => Ok(empty_res(StatusCode::NOT_FOUND)), + } +} + +async fn handle_webhook_post( + req: Request, + config: &Config, +) -> Result>, hyper::Error> { + let (head, body) = req.into_parts(); + + // Extract the event type early on. This allows us to exit before doing expensive signature + // checking, if the header is missing or invalid ASCII. + let event = match head.headers.get("X-GitHub-event").map(HeaderValue::to_str) { + Some(Ok(event)) => event, + Some(Err(_)) => return Ok(full_res("Invalid ASCII in header: X-GitHub-Event", StatusCode::BAD_REQUEST)), + None => return Ok(full_res("Missing header: X-GitHub-Event", StatusCode::BAD_REQUEST)), + }; + + // Read entire body into `Bytes`. We have to set an upper limit to protect the server from + // massive allocations. + let upper = body.size_hint().upper().unwrap_or(u64::MAX); + if upper > 1024 * 64 { + eprintln!("Rejecting request because payload is too large."); + return Ok(full_res("Body too big", StatusCode::PAYLOAD_TOO_LARGE)); + } + let body = body.collect().await?.to_bytes(); + + // Now that we have read the entire body, we should validate the signature before proceeding. + if !validate_request(&config.secret, &head.headers, &body) { + eprintln!("Rejecting request becuase signature is missing or invaldi"); + return Ok(full_res("Missing or invalid signature", StatusCode::BAD_REQUEST)); + } + + for command in &config.commands { + if command.event == event { + let command_clone = command.clone(); + let body_clone = body.clone(); + tokio::spawn(async move { + match run_command(&command_clone, body_clone.as_ref()).await { + Ok(s) => match s.code() { + Some(code) => println!("Command finished with exit code {}: {:?}", code, command_clone), + None => println!("Command finished without exit code: {:?}", command_clone), + }, + Err(e) => eprintln!("Failed to spawn command: {:?}\nerror: {}", command_clone, e), + } + }); + } + } + + Ok(empty_res(StatusCode::NO_CONTENT)) +} + +async fn run_command(command: &config::Command, body: &[u8]) -> io::Result { + let mut child = Command::new(&command.command) + .stdin(Stdio::piped()) // We will feed the event data through stdin. + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .args(&command.args) + .spawn()?; + + // Feed data through stdin. Sure hope whatever a "deadlock" is doesn't happen here. + let mut child_stdin = child.stdin.take().expect("child has stdin"); + child_stdin.write_all(body).await?; + drop(child_stdin); + + Ok(child.wait().await?) +} + +/// Utility to create an empty response. +fn empty_res(status: StatusCode) -> Response> { + let body = Empty::::new() + .map_err(|never| match never {}) + .boxed(); + + let mut response = Response::new(body); + *response.status_mut() = status; + response +} + +/// Utility to create a full (i.e. with content) response. +fn full_res>( + chunk: T, + status: StatusCode, +) -> Response> { + let body = Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed(); + + let mut response = Response::new(body); + *response.status_mut() = status; + response +} + +/// Decodes a string slice into a string of bytes. +/// +/// Implementation taken from [this stackoverflow post](https://stackoverflow.com/a/52992629). +fn decode_hex(s: &str) -> Result, ParseIntError> { + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16)) + .collect() +} + +/// Validates the signature that GitHub attaches to events. +/// +/// +fn validate_request(secret: &String, headers: &HeaderMap, body: &Bytes) -> bool { + // To verify the authenticity of the event, GitHub attaches a signature of the payload to + // every request. We extract the header. The header value will look something like this: + // + // x-hub-signature-256: sha256=6803d2a3e495fc4bd286d428ea4b794476a1ff1b72bbea4dfafd2477d5d89188 + let maybe_signature = headers + .get("x-hub-signature-256") + .and_then(|hv| hv.to_str().ok()) // HeaderValue => &str + .and_then(|s| s.strip_prefix("sha256=")) // sha256=2843i4aklds... => 2843i4aklds... + .and_then(|s| decode_hex(s).ok()); // &str -> vec + let signature = match maybe_signature { + Some(s) => s, + None => return false, // Missing or invalid signature + }; + + // Now we independantly calculate a signature of the payload we just read, using the secret. If + // Github computed the signature with the same secret, we should be all good. + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).unwrap(); + mac.update(&body); + mac.verify_slice(&signature).is_ok() +} diff --git a/src/systemd_socket.rs b/src/systemd_socket.rs new file mode 100644 index 0000000..043fd8e --- /dev/null +++ b/src/systemd_socket.rs @@ -0,0 +1,473 @@ +//! `systemd_socket` implements the daemon side of the socket activation. The interface is similar +//! to the one provided by the systemd/sd-daemon library, but adjusted for easier usage in rust. It +//! relies on `nix` for all low-level operations. All checks are ported from the systemd code. +//! +//! Enums required for socket type (`SockType`) and address family (`AddressFamily`) are reexported +//! from nix. +//! +//! The library is based on [rust-systemd](https://github.com/jmesmon/rust-systemd) by Cody P +//! Schafer, but it does not require any extra libraries and works on rust stable. + +// I'm hoping to bring this module with me to other packages, so let's just allow all the functions +// which _are_ useful, just not for this project. That's why there are a lot of `allow(dead_code)` +// in this module. + +use nix::fcntl; +use nix::libc; +use nix::sys::socket::{self, SockaddrLike}; +use nix::sys::stat; +use nix::unistd::Pid; +use std::collections::HashMap; +use std::convert::From; +use std::env; +use std::error::Error as StdError; +use std::fmt; +use std::num::ParseIntError; +use std::os::unix::io::{OwnedFd, RawFd}; +use std::os::fd::{AsFd, AsRawFd, FromRawFd}; +use std::path; + +pub use nix::sys::socket::SockType; +pub use nix::sys::socket::AddressFamily; + +const VAR_FDS: &'static str = "LISTEN_FDS"; +const VAR_NAMES: &'static str = "LISTEN_FDNAMES"; +const VAR_PID: &'static str = "LISTEN_PID"; + +#[derive(Debug, PartialEq)] +pub enum Error { + Var(env::VarError), + Parse(ParseIntError), + DifferentProcess, + InvalidVariableValue, + Nix(nix::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl StdError for Error { + fn description(&self) -> &str { + match self { + &Error::InvalidVariableValue => "Environment variable could not be parsed", + &Error::DifferentProcess => + "Environment variables are meant for a different process (pid mismatch)", + &Error::Var(_) => "Required environment variable missing or unreadable", + &Error::Parse(_) => "Could not parse number in 'LISTEN_FDS'", + &Error::Nix(_) => "Calling system function on socket failed", + } + } + + fn cause(&self) -> Option<&dyn StdError> { + match self { + &Error::Var(ref e) => Some(e), + &Error::Parse(ref e) => Some(e), + &Error::Nix(ref e) => Some(e), + _ => None, + } + } +} + +impl From for Error { + fn from(e: env::VarError) -> Error { + Error::Var(e) + } +} + +impl From for Error { + fn from(e: ParseIntError) -> Error { + Error::Parse(e) + } +} + +impl From for Error { + fn from(e: nix::Error) -> Error { + Error::Nix(e) + } +} + +/// Encapsulates the possible failure modes of local functions. +pub type Result = std::result::Result; + +/// Number of the first passed file descriptor +const LISTEN_FDS_START: RawFd = 3; + +fn unset_all_env() { + env::remove_var(VAR_PID); + env::remove_var(VAR_FDS); + env::remove_var(VAR_NAMES); +} + +/// Returns the file descriptors passed in by init process. Removes the `$LISTEN_FDS` and +/// `$LISTEN_PID` variables from the environment if `unset_environment` is `true`. +pub fn listen_fds(unset_environment: bool) -> Result> { + let pid_str = env::var(VAR_PID)?; + let pid_raw: libc::pid_t = pid_str.parse()?; + let pid = Pid::from_raw(pid_raw); + + if pid != nix::unistd::getpid() { + return Err(Error::DifferentProcess); + } + + let fds_str = env::var(VAR_FDS)?; + let fds: libc::c_int = fds_str.parse()?; + + if fds < 0 { + return Err(Error::InvalidVariableValue); + } + + for fd in LISTEN_FDS_START..(LISTEN_FDS_START+fds) { + fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFD(fcntl::FdFlag::FD_CLOEXEC))?; + } + + if unset_environment { + unset_all_env(); + } + let fd_vec: Vec<_> = (LISTEN_FDS_START .. (LISTEN_FDS_START+fds)) + .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) }) + .collect(); + Ok(fd_vec) +} + +/// Returns file descriptors with names. Removes the `$LISTEN_FDS` and `$LISTEN_PID` variables from +/// the environment if `unset_environment` is `true`. +#[allow(unused)] +pub fn listen_fds_with_names(unset_environment: bool) -> Result> { + let names_str = env::var(VAR_NAMES)?; + let names: Vec<&str> = names_str.split(':').collect(); + + let fds: Vec = listen_fds(unset_environment)?; + if fds.len() != names.len() { + return Err(Error::InvalidVariableValue); + } + + let mut map = HashMap::new(); + for (name, fd) in names.into_iter().zip(fds) { + map.insert(name.to_string(), fd); + } + Ok(map) +} + +/// Identifies whether the passed file descriptor is a FIFO. If a path is +/// supplied, the file descriptor must also match the path. +#[allow(unused)] +pub fn is_fifo(fd: T, path: Option<&str>) -> Result { + let fs = stat::fstat(fd.as_raw_fd())?; + let mode = stat::SFlag::from_bits_truncate(fs.st_mode); + if !mode.contains(stat::SFlag::S_IFIFO) { + return Ok(false); + } + if let Some(path_str) = path { + let path_stat = match stat::stat(path::Path::new(path_str)) { + Ok(x) => x, + Err(_) => {return Ok(false)}, + }; + return Ok(path_stat.st_dev == fs.st_dev && path_stat.st_ino == fs.st_ino); + } + Ok(true) +} + +/// Identifies whether the passed file descriptor is a special character device. +/// If a path is supplied, the file descriptor must also match the path. +#[allow(unused)] +pub fn is_special(fd: T, path: Option<&str>) -> Result { + let fs = stat::fstat(fd.as_raw_fd())?; + let mode = stat::SFlag::from_bits_truncate(fs.st_mode); + if !mode.contains(stat::SFlag::S_IFREG) && !mode.contains(stat::SFlag::S_IFCHR) { + // path not comparable + return Ok(true); + } + + if let Some(path_str) = path { + let path_stat = match stat::stat(path::Path::new(path_str)) { + Ok(x) => x, + Err(_) => {return Ok(false)}, + }; + + let path_mode = stat::SFlag::from_bits_truncate(path_stat.st_mode); + if (mode & path_mode).contains(stat::SFlag::S_IFREG) { + return Ok(path_stat.st_dev == fs.st_dev && path_stat.st_ino == fs.st_ino); + } + + if (mode & path_mode).contains(stat::SFlag::S_IFCHR) { + return Ok(path_stat.st_rdev == fs.st_rdev); + } + + return Ok(false); + } + + Ok(true) +} + +/// Do checks common to all socket verification functions. (type, listening state) +#[allow(unused)] +fn is_socket_internal(fd: &T, socktype: Option, + listening: Option) -> Result { + /*if fd < 0 { + return Err(Error::InvalidFdValue); + }*/ + + let fs = stat::fstat(fd.as_fd().as_raw_fd())?; + let mode = stat::SFlag::from_bits_truncate(fs.st_mode); + if !mode.contains(stat::SFlag::S_IFSOCK) { + return Ok(false); + } + if let Some(val) = socktype { + let typ: SockType = socket::getsockopt(&fd, socket::sockopt::SockType)?; + if typ != val { + return Ok(false); + } + } + + if let Some(val) = listening { + // This is broken on MacOS, as according to [getsockopt(2)] and [this stackoverflow + // anser][so], `SO_ACCEPTCONN` is not + // supported at the `SOL_SOCKET` level. I assume this also applies to other platforms using + // the Darwin kernel, i.e. all Apple's platfroms. + // + // [getsockopt(2)]: https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/getsockopt.2.html + // [so]: https://stackoverflow.com/a/75943802 + if cfg!(target_vendor = "apple") { + todo!("Getting listening state is not implemented on Apple's Darwin kernel"); + } + + let acc = socket::getsockopt(&fd, socket::sockopt::AcceptConn)?; + if acc != val { + return Ok(false); + } + } + + Ok(true) +} + +/// Identifies whether the passed file descriptor is a socket. If family, +/// type, and listening state are supplied, they must match as well. +#[allow(unused)] +pub fn is_socket(fd: &T, family: Option, socktype: Option, + listening: Option) -> Result { + if !is_socket_internal(fd, socktype, listening)? { + return Ok(false); + } + + if let Some(f) = family { + let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?; + let sock_family = sock_addr.family().unwrap(); + if sock_family != f { + return Ok(false); + } + } + + Ok(true) +} + +/// Identifies whether the passed file descriptor is an Internet socket. If family, type, listening +/// state, and/or port are supplied, they must match as well. +pub fn is_socket_inet(fd: &T, family: Option, socktype: Option, + listening: Option, port: Option) -> Result { + if !is_socket_internal(fd, socktype, listening)? { + return Ok(false); + } + + let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?; + let sock_family = sock_addr.family().unwrap(); + if sock_family != AddressFamily::Inet && sock_family != AddressFamily::Inet6 { + return Ok(false); + } + + if let Some(val) = family { + if sock_family != val { + return Ok(false); + } + } + + if let Some(expected_port) = port { + let port = match sock_family { + socket::AddressFamily::Inet => sock_addr.as_sockaddr_in().unwrap().port(), + socket::AddressFamily::Inet6 => sock_addr.as_sockaddr_in6().unwrap().port(), + _ => unreachable!(), + }; + if port != expected_port { + return Ok(false); + } + } + + Ok(true) +} + +/// Identifies whether the passed file descriptor is an AF_UNIX socket. If type are supplied, it +/// must match as well. Path checking is currently unsupported and will be ignored +#[allow(unused)] +pub fn is_socket_unix(fd: &T, socktype: Option, listening: Option, + path: Option<&str>) -> Result { + if !is_socket_internal(fd, socktype, listening)? { + return Ok(false); + } + + let sock_addr: socket::SockaddrStorage = socket::getsockname(fd.as_fd().as_raw_fd())?; + let sock_family = sock_addr.family().unwrap(); + if sock_family != AddressFamily::Unix { + return Ok(false); + } + + if let Some(_val) = path { + // TODO: unsupported + } + + Ok(true) +} + +// TODO +///// Identifies whether the passed file descriptor is a POSIX message queue. If a +///// path is supplied, it will also verify the name. +//pub fn is_mq(fd: RawFd, path: Option<&str>) -> Result { +//} + +#[cfg(test)] +mod tests { + use ::nix; + use ::lazy_static::lazy_static; + use ::std::env; + use ::std::os::unix::io::OwnedFd; + use ::std::os::fd::{AsRawFd, FromRawFd, RawFd}; + use ::std::sync::{Mutex,MutexGuard}; + use ::std::mem; + + // Even with one -j1, cargo runs multiple tests at once. That doesn't work with environment + // variables, or specific socket ordering, so mutexes are required. + lazy_static! { + static ref LOCK: Mutex<()> = Mutex::new(()); + } + + fn lock_env<'a>() -> MutexGuard<'a, ()> { + // SAFETY: We can ignore `PoisonError`s since the ressource we are locking is just `()`. + // See: . + LOCK.lock().unwrap_or_else(|e| e.into_inner()) + } + + fn set_current_pid() { + let pid = nix::unistd::getpid(); + env::set_var(super::VAR_PID, format!("{}", pid)); + } + + /// Create a new socket with the given `family` and `typ`e. + /// + /// This function is used by the `is_*` tests, so it returns an owned ressource (as opposed to + /// [`create_socket_with_fd`](self::create_socket_with_fd)). + fn create_socket(family: super::AddressFamily, typ: super::SockType) -> OwnedFd { + nix::sys::socket::socket(family, typ, nix::sys::socket::SockFlag::empty(), None).unwrap() + } + + /// Create a new socket with the given `family` and `typ`e, asserting that it gets assigned a + /// specific fd. + /// + /// This function is used to simulate Systemd opening a socket for us, so the actual ressource + /// is ["forgotten"](std::mem::forget). + fn create_socket_with_fd(no: nix::libc::c_int, family: super::AddressFamily, typ: super::SockType) { + debug_assert!(no > 0, "Valid file descriptors are always positive"); + + // Allocate a socket. During normal operation, this would be done by Systemd before our + // program was even started. + let fd = create_socket(family, typ); + assert_eq!(fd.as_raw_fd(), no, "Expected new socket to have fd {} but got {}", no, fd.as_raw_fd()); + + // We don't want Rust to manage the ressource for us (YET), as this function is supposed to + // mimic how Systemd would file descriptors to us. + mem::forget(fd); + } + + /// Returns a file descriptor for a regular file. + fn open_file() -> OwnedFd { + let path = ::std::path::Path::new("/etc/hosts"); + let fd = nix::fcntl::open(path, nix::fcntl::OFlag::O_RDONLY, nix::sys::stat::Mode::empty()).unwrap(); + unsafe { OwnedFd::from_raw_fd(fd) } + } + + #[test] + fn listen_fds_success() { + let _l = lock_env(); + set_current_pid(); + let _fd = create_socket_with_fd(3, super::AddressFamily::Inet, super::SockType::Stream); + env::set_var(super::VAR_FDS, "1"); + let fds = super::listen_fds(true).unwrap(); + assert_eq!(fds.len(), 1); + assert_eq!(fds[0].as_raw_fd(), 3); + } + + #[test] + fn names() { + let _l = lock_env(); + set_current_pid(); + env::set_var(super::VAR_FDS, "2"); + env::set_var(super::VAR_NAMES, "a:b"); + let _fd1 = create_socket_with_fd(3, super::AddressFamily::Inet, super::SockType::Stream); + let _fd2 = create_socket_with_fd(4, super::AddressFamily::Inet, super::SockType::Stream); + let fds = super::listen_fds_with_names(true).unwrap(); + assert_eq!(fds.len(), 2); + assert_eq!(fds["a"].as_raw_fd(), 3); + assert_eq!(fds["b"].as_raw_fd(), 4); + } + + #[test] + fn listen_fds_cleans() { + let _l = lock_env(); + set_current_pid(); + env::set_var(super::VAR_FDS, "0"); + super::listen_fds(false).unwrap(); + assert_eq!(env::var(super::VAR_FDS), Ok("0".into())); + super::listen_fds(true).unwrap(); + assert_eq!(env::var(super::VAR_FDS), Err(env::VarError::NotPresent)); + assert_eq!(env::var(super::VAR_PID), Err(env::VarError::NotPresent)); + assert_eq!(env::var(super::VAR_NAMES), Err(env::VarError::NotPresent)); + } + + #[test] + fn is_socket() { + let _l = lock_env(); + + let fd = create_socket(super::AddressFamily::Inet, super::SockType::Stream); + assert!(super::is_socket(&fd, None, None, None).unwrap()); + #[cfg(not(target_vendor = "apple"))] + assert!(super::is_socket(&fd, Some(super::AddressFamily::Inet), + Some(super::SockType::Stream), Some(false)).unwrap()); + #[cfg(target_vendor = "apple")] + assert!(super::is_socket(&fd, Some(super::AddressFamily::Inet), + Some(super::SockType::Stream), None).unwrap()); + + let fd = open_file(); + assert!(!super::is_socket(&fd, None, None, None).unwrap()); + } + + #[test] + fn is_socket_inet() { + let _l = lock_env(); + let fd = create_socket(super::AddressFamily::Inet, super::SockType::Stream); + assert!(super::is_socket_inet(&fd, None, None, None, None).unwrap()); + #[cfg(not(target_vendor = "apple"))] + assert!(super::is_socket_inet(&fd, Some(super::AddressFamily::Inet), + Some(super::SockType::Stream), Some(false), None).unwrap()); + #[cfg(target_vendor = "apple")] + assert!(super::is_socket_inet(&fd, Some(super::AddressFamily::Inet), + Some(super::SockType::Stream), None, None).unwrap()); + + let fd = open_file(); + assert!(!super::is_socket_inet(&fd, None, None, None, None).unwrap()); + } + + #[test] + fn is_socket_unix() { + let _l = lock_env(); + let fd = create_socket(super::AddressFamily::Unix, super::SockType::Stream); + assert!(super::is_socket_unix(&fd, None, None, None).unwrap()); + #[cfg(not(target_vendor = "apple"))] + assert!(super::is_socket_unix(&fd, Some(super::SockType::Stream), + Some(false), None).unwrap()); + #[cfg(target_vendor = "apple")] + assert!(super::is_socket_unix(&fd, Some(super::SockType::Stream), None, None).unwrap()); + + let fd = open_file(); + assert!(!super::is_socket_unix(&fd, None, None, None).unwrap()); + } +} -- cgit v1.2.3