From 50d028f45e0a943d9b1faad14099d5d01e03faa3 Mon Sep 17 00:00:00 2001 From: "Dustin J. Mitchell" Date: Thu, 21 Nov 2024 21:27:52 -0500 Subject: [PATCH] Support a client-id allowlist (#62) This will support setting up publicly-accessible personal servers, without also allowing anyone to create a new client. --- README.md | 13 ++++ server/src/api/add_snapshot.rs | 12 +-- server/src/api/add_version.rs | 17 ++-- server/src/api/get_child_version.rs | 12 +-- server/src/api/get_snapshot.rs | 10 +-- server/src/api/mod.rs | 75 +++++++++++++++--- server/src/bin/taskchampion-sync-server.rs | 90 ++++++++++++++++++---- server/src/lib.rs | 12 ++- 8 files changed, 188 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index e545a3c..2ff0a8c 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,19 @@ It is comprised of three crates: - `taskchmpaion-sync-server-sqlite` implements an SQLite backend for the core - `taskchampion-sync-server` implements a simple HTTP server for the protocol +## Running the Server + +The server is configured with command-line options. See +`taskchampion-sync-server --help` for full details. + +The `--data-dir` option specifies where the server should store its data, and +`--port` gives the port on which the HTTP server runs. The server does not +implement TLS; for public deployments, the recommendation is to use a reverse +proxy such as Nginx, haproxy, or Apache httpd. + +By default, the server allows all client IDs. To limit the accepted client IDs, +such as when running a personal server, use `--allow-client-id `. + ## Installation ### As container diff --git a/server/src/api/add_snapshot.rs b/server/src/api/add_snapshot.rs index b5d93cf..783bd6c 100644 --- a/server/src/api/add_snapshot.rs +++ b/server/src/api/add_snapshot.rs @@ -1,4 +1,4 @@ -use crate::api::{client_id_header, server_error_to_actix, ServerState, SNAPSHOT_CONTENT_TYPE}; +use crate::api::{server_error_to_actix, ServerState, SNAPSHOT_CONTENT_TYPE}; use actix_web::{error, post, web, HttpMessage, HttpRequest, HttpResponse, Result}; use futures::StreamExt; use std::sync::Arc; @@ -29,7 +29,7 @@ pub(crate) async fn service( return Err(error::ErrorBadRequest("Bad content-type")); } - let client_id = client_id_header(&req)?; + let client_id = server_state.client_id_header(&req)?; // read the body in its entirety let mut body = web::BytesMut::new(); @@ -75,7 +75,7 @@ mod test { txn.add_version(client_id, version_id, NIL_VERSION_ID, vec![])?; } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -117,7 +117,7 @@ mod test { txn.new_client(client_id, NIL_VERSION_ID).unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -147,7 +147,7 @@ mod test { let client_id = Uuid::new_v4(); let version_id = Uuid::new_v4(); let storage = InMemoryStorage::new(); - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -167,7 +167,7 @@ mod test { let client_id = Uuid::new_v4(); let version_id = Uuid::new_v4(); let storage = InMemoryStorage::new(); - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; diff --git a/server/src/api/add_version.rs b/server/src/api/add_version.rs index 82dd89c..12d80e3 100644 --- a/server/src/api/add_version.rs +++ b/server/src/api/add_version.rs @@ -1,7 +1,6 @@ use crate::api::{ - client_id_header, failure_to_ise, server_error_to_actix, ServerState, - HISTORY_SEGMENT_CONTENT_TYPE, PARENT_VERSION_ID_HEADER, SNAPSHOT_REQUEST_HEADER, - VERSION_ID_HEADER, + failure_to_ise, server_error_to_actix, ServerState, HISTORY_SEGMENT_CONTENT_TYPE, + PARENT_VERSION_ID_HEADER, SNAPSHOT_REQUEST_HEADER, VERSION_ID_HEADER, }; use actix_web::{error, post, web, HttpMessage, HttpRequest, HttpResponse, Result}; use futures::StreamExt; @@ -40,7 +39,7 @@ pub(crate) async fn service( return Err(error::ErrorBadRequest("Bad content-type")); } - let client_id = client_id_header(&req)?; + let client_id = server_state.client_id_header(&req)?; // read the body in its entirety let mut body = web::BytesMut::new(); @@ -116,7 +115,7 @@ mod test { txn.new_client(client_id, Uuid::nil()).unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -150,7 +149,7 @@ mod test { let client_id = Uuid::new_v4(); let version_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4(); - let server = WebServer::new(Default::default(), InMemoryStorage::new()); + let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -201,7 +200,7 @@ mod test { txn.new_client(client_id, version_id).unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -229,7 +228,7 @@ mod test { let client_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4(); let storage = InMemoryStorage::new(); - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -249,7 +248,7 @@ mod test { let client_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4(); let storage = InMemoryStorage::new(); - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; diff --git a/server/src/api/get_child_version.rs b/server/src/api/get_child_version.rs index f1b0a0a..0779415 100644 --- a/server/src/api/get_child_version.rs +++ b/server/src/api/get_child_version.rs @@ -1,6 +1,6 @@ use crate::api::{ - client_id_header, server_error_to_actix, ServerState, HISTORY_SEGMENT_CONTENT_TYPE, - PARENT_VERSION_ID_HEADER, VERSION_ID_HEADER, + server_error_to_actix, ServerState, HISTORY_SEGMENT_CONTENT_TYPE, PARENT_VERSION_ID_HEADER, + VERSION_ID_HEADER, }; use actix_web::{error, get, web, HttpRequest, HttpResponse, Result}; use std::sync::Arc; @@ -21,7 +21,7 @@ pub(crate) async fn service( path: web::Path, ) -> Result { let parent_version_id = path.into_inner(); - let client_id = client_id_header(&req)?; + let client_id = server_state.client_id_header(&req)?; return match server_state .server @@ -70,7 +70,7 @@ mod test { .unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -104,7 +104,7 @@ mod test { let client_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4(); let storage = InMemoryStorage::new(); - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -132,7 +132,7 @@ mod test { txn.add_version(client_id, test_version_id, NIL_VERSION_ID, b"vers".to_vec()) .unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; diff --git a/server/src/api/get_snapshot.rs b/server/src/api/get_snapshot.rs index 47a728d..66b8a77 100644 --- a/server/src/api/get_snapshot.rs +++ b/server/src/api/get_snapshot.rs @@ -1,6 +1,4 @@ -use crate::api::{ - client_id_header, server_error_to_actix, ServerState, SNAPSHOT_CONTENT_TYPE, VERSION_ID_HEADER, -}; +use crate::api::{server_error_to_actix, ServerState, SNAPSHOT_CONTENT_TYPE, VERSION_ID_HEADER}; use actix_web::{error, get, web, HttpRequest, HttpResponse, Result}; use std::sync::Arc; @@ -17,7 +15,7 @@ pub(crate) async fn service( req: HttpRequest, server_state: web::Data>, ) -> Result { - let client_id = client_id_header(&req)?; + let client_id = server_state.client_id_header(&req)?; if let Some((version_id, data)) = server_state .server @@ -54,7 +52,7 @@ mod test { txn.new_client(client_id, Uuid::new_v4()).unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; @@ -90,7 +88,7 @@ mod test { .unwrap(); } - let server = WebServer::new(Default::default(), storage); + let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await; diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 25812da..5ffb18e 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -1,5 +1,8 @@ +use std::collections::HashSet; + use actix_web::{error, web, HttpRequest, Result, Scope}; use taskchampion_sync_server_core::{ClientId, Server, ServerError}; +use uuid::Uuid; mod add_snapshot; mod add_version; @@ -28,6 +31,28 @@ pub(crate) const SNAPSHOT_REQUEST_HEADER: &str = "X-Snapshot-Request"; /// The type containing a reference to the persistent state for the server pub(crate) struct ServerState { pub(crate) server: Server, + pub(crate) client_id_allowlist: Option>, +} + +impl ServerState { + /// Get the client id + fn client_id_header(&self, req: &HttpRequest) -> Result { + fn badrequest() -> error::Error { + error::ErrorBadRequest("bad x-client-id") + } + if let Some(client_id_hdr) = req.headers().get(CLIENT_ID_HEADER) { + let client_id = client_id_hdr.to_str().map_err(|_| badrequest())?; + let client_id = ClientId::parse_str(client_id).map_err(|_| badrequest())?; + if let Some(allow_list) = &self.client_id_allowlist { + if !allow_list.contains(&client_id) { + return Err(error::ErrorForbidden("unknown x-client-id")); + } + } + Ok(client_id) + } else { + Err(badrequest()) + } + } } pub(crate) fn api_scope() -> Scope { @@ -51,16 +76,46 @@ fn server_error_to_actix(err: ServerError) -> actix_web::Error { } } -/// Get the client id -fn client_id_header(req: &HttpRequest) -> Result { - fn badrequest() -> error::Error { - error::ErrorBadRequest("bad x-client-id") +#[cfg(test)] +mod test { + use super::*; + use taskchampion_sync_server_core::InMemoryStorage; + + #[test] + fn client_id_header_allow_all() { + let client_id = Uuid::new_v4(); + let state = ServerState { + server: Server::new(Default::default(), InMemoryStorage::new()), + client_id_allowlist: None, + }; + let req = actix_web::test::TestRequest::default() + .insert_header((CLIENT_ID_HEADER, client_id.to_string())) + .to_http_request(); + assert_eq!(state.client_id_header(&req).unwrap(), client_id); } - if let Some(client_id_hdr) = req.headers().get(CLIENT_ID_HEADER) { - let client_id = client_id_hdr.to_str().map_err(|_| badrequest())?; - let client_id = ClientId::parse_str(client_id).map_err(|_| badrequest())?; - Ok(client_id) - } else { - Err(badrequest()) + + #[test] + fn client_id_header_allow_list() { + let client_id_ok = Uuid::new_v4(); + let client_id_disallowed = Uuid::new_v4(); + let state = ServerState { + server: Server::new(Default::default(), InMemoryStorage::new()), + client_id_allowlist: Some([client_id_ok].into()), + }; + let req = actix_web::test::TestRequest::default() + .insert_header((CLIENT_ID_HEADER, client_id_ok.to_string())) + .to_http_request(); + assert_eq!(state.client_id_header(&req).unwrap(), client_id_ok); + let req = actix_web::test::TestRequest::default() + .insert_header((CLIENT_ID_HEADER, client_id_disallowed.to_string())) + .to_http_request(); + assert_eq!( + state + .client_id_header(&req) + .unwrap_err() + .as_response_error() + .status_code(), + 403 + ); } } diff --git a/server/src/bin/taskchampion-sync-server.rs b/server/src/bin/taskchampion-sync-server.rs index 1057b45..ecc1046 100644 --- a/server/src/bin/taskchampion-sync-server.rs +++ b/server/src/bin/taskchampion-sync-server.rs @@ -1,19 +1,18 @@ #![deny(clippy::all)] use actix_web::{middleware::Logger, App, HttpServer}; -use clap::{arg, builder::ValueParser, value_parser, Command}; -use std::ffi::OsString; +use clap::{arg, builder::ValueParser, value_parser, ArgAction, Command}; +use std::{collections::HashSet, ffi::OsString}; use taskchampion_sync_server::WebServer; use taskchampion_sync_server_core::ServerConfig; use taskchampion_sync_server_storage_sqlite::SqliteStorage; +use uuid::Uuid; -#[actix_web::main] -async fn main() -> anyhow::Result<()> { - env_logger::init(); +fn command() -> Command { let defaults = ServerConfig::default(); let default_snapshot_versions = defaults.snapshot_versions.to_string(); let default_snapshot_days = defaults.snapshot_days.to_string(); - let matches = Command::new("taskchampion-sync-server") + Command::new("taskchampion-sync-server") .version(env!("CARGO_PKG_VERSION")) .about("Server for TaskChampion") .arg( @@ -27,6 +26,12 @@ async fn main() -> anyhow::Result<()> { .value_parser(ValueParser::os_string()) .default_value("/var/lib/taskchampion-sync-server"), ) + .arg( + arg!(-C --"allow-client-id" "Client IDs to allow (can be repeated; if not specified, all clients are allowed)") + .value_parser(value_parser!(Uuid)) + .action(ArgAction::Append) + .required(false), + ) .arg( arg!(--"snapshot-versions" "Target number of versions between snapshots") .value_parser(value_parser!(u32)) @@ -37,18 +42,26 @@ async fn main() -> anyhow::Result<()> { .value_parser(value_parser!(i64)) .default_value(default_snapshot_days), ) - .get_matches(); +} + +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + let matches = command().get_matches(); let data_dir: &OsString = matches.get_one("data-dir").unwrap(); let port: usize = *matches.get_one("port").unwrap(); let snapshot_versions: u32 = *matches.get_one("snapshot-versions").unwrap(); let snapshot_days: i64 = *matches.get_one("snapshot-days").unwrap(); + let client_id_allowlist: Option> = matches + .get_many("allow-client-id") + .map(|ids| ids.copied().collect()); let config = ServerConfig { snapshot_days, snapshot_versions, }; - let server = WebServer::new(config, SqliteStorage::new(data_dir)?); + let server = WebServer::new(config, client_id_allowlist, SqliteStorage::new(data_dir)?); log::info!("Serving on port {}", port); HttpServer::new(move || { @@ -65,17 +78,68 @@ async fn main() -> anyhow::Result<()> { #[cfg(test)] mod test { use super::*; - use actix_web::{test, App}; + use actix_web::{self, App}; + use clap::ArgMatches; use taskchampion_sync_server_core::InMemoryStorage; + /// Get the list of allowed client IDs + fn allowed(matches: &ArgMatches) -> Option> { + matches + .get_many::("allow-client-id") + .map(|ids| ids.copied().collect::>()) + } + + #[test] + fn command_allowed_client_ids_none() { + let matches = command().get_matches_from(["tss"]); + assert_eq!(allowed(&matches), None); + } + + #[test] + fn command_allowed_client_ids_one() { + let matches = + command().get_matches_from(["tss", "-C", "711d5cf3-0cf0-4eb8-9eca-6f7f220638c0"]); + assert_eq!( + allowed(&matches), + Some(vec![Uuid::parse_str( + "711d5cf3-0cf0-4eb8-9eca-6f7f220638c0" + ) + .unwrap()]) + ); + } + + #[test] + fn command_allowed_client_ids_two() { + let matches = command().get_matches_from([ + "tss", + "-C", + "711d5cf3-0cf0-4eb8-9eca-6f7f220638c0", + "-C", + "bbaf4b61-344a-4a39-a19e-8caa0669b353", + ]); + assert_eq!( + allowed(&matches), + Some(vec![ + Uuid::parse_str("711d5cf3-0cf0-4eb8-9eca-6f7f220638c0").unwrap(), + Uuid::parse_str("bbaf4b61-344a-4a39-a19e-8caa0669b353").unwrap() + ]) + ); + } + + #[test] + fn command_data_dir() { + let matches = command().get_matches_from(["tss", "--data-dir", "/foo/bar"]); + assert_eq!(matches.get_one::("data-dir").unwrap(), "/foo/bar"); + } + #[actix_rt::test] async fn test_index_get() { - let server = WebServer::new(Default::default(), InMemoryStorage::new()); + let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let app = App::new().configure(|sc| server.config(sc)); - let app = test::init_service(app).await; + let app = actix_web::test::init_service(app).await; - let req = test::TestRequest::get().uri("/").to_request(); - let resp = test::call_service(&app, req).await; + let req = actix_web::test::TestRequest::get().uri("/").to_request(); + let resp = actix_web::test::call_service(&app, req).await; assert!(resp.status().is_success()); } } diff --git a/server/src/lib.rs b/server/src/lib.rs index 2d81e58..13827a5 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -4,8 +4,9 @@ mod api; use actix_web::{get, middleware, web, Responder}; use api::{api_scope, ServerState}; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; use taskchampion_sync_server_core::{Server, ServerConfig, Storage}; +use uuid::Uuid; #[get("/")] async fn index() -> impl Responder { @@ -20,10 +21,15 @@ pub struct WebServer { impl WebServer { /// Create a new sync server with the given storage implementation. - pub fn new(config: ServerConfig, storage: ST) -> Self { + pub fn new( + config: ServerConfig, + client_id_allowlist: Option>, + storage: ST, + ) -> Self { Self { server_state: Arc::new(ServerState { server: Server::new(config, storage), + client_id_allowlist, }), } } @@ -51,7 +57,7 @@ mod test { #[actix_rt::test] async fn test_cache_control() { - let server = WebServer::new(Default::default(), InMemoryStorage::new()); + let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let app = App::new().configure(|sc| server.config(sc)); let app = test::init_service(app).await;