Merge pull request #125 from djmitche/issue121

Allow disabling automatic creation of clients
This commit is contained in:
Dustin J. Mitchell 2025-07-12 09:48:44 -04:00 committed by GitHub
commit 87d1d026b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 181 additions and 54 deletions

View file

@ -85,11 +85,20 @@ list of values.
The `--data-dir` option specifies where the server should store its data. This The `--data-dir` option specifies where the server should store its data. This
value can be specified in the environment variable `DATA_DIR`. value can be specified in the environment variable `DATA_DIR`.
By default, the server allows all client IDs. To limit the accepted client IDs, By default, the server will allow all clients and create them in the database
specify them in the environment variable `CLIENT_ID`, as a comma-separated list on first contact. There are two ways to limit the clients the server will
of UUIDs. Client IDs can be specified with `--allow-client-id`, but this should interact with:
not be used on shared systems, as command line arguments are visible to all
users on the system. - To limit the accepted client IDs, specify them in the environment variable
`CLIENT_ID`, as a comma-separated list of UUIDs. Client IDs can be specified
with `--allow-client-id`, but this should not be used on shared systems, as
command line arguments are visible to all users on the system. This convenient
option is suitable for personal and small-scale deployments.
- To disable the automatic creation of clients, use the `--no-create-clients`
flag or the `CREATE_CLIENTS=false` environment variable. You are now
responsible for creating clients in the database manually, so this option is
more suitable for large scale deployments.
The server only logs errors by default. To add additional logging output, set The server only logs errors by default. To add additional logging output, set
environment variable `RUST_LOG` to `info` to get a log message for every environment variable `RUST_LOG` to `info` to get a log message for every

View file

@ -55,11 +55,11 @@ pub(crate) async fn service(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::api::CLIENT_ID_HEADER;
use crate::WebServer; use crate::WebServer;
use crate::{api::CLIENT_ID_HEADER, WebConfig};
use actix_web::{http::StatusCode, test, App}; use actix_web::{http::StatusCode, test, App};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use taskchampion_sync_server_core::{InMemoryStorage, Storage, NIL_VERSION_ID}; use taskchampion_sync_server_core::{InMemoryStorage, ServerConfig, Storage, NIL_VERSION_ID};
use uuid::Uuid; use uuid::Uuid;
#[actix_rt::test] #[actix_rt::test]
@ -76,11 +76,11 @@ mod test {
txn.commit()?; txn.commit()?;
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-snapshot/{}", version_id); let uri = format!("/v1/client/add-snapshot/{version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.insert_header(("Content-Type", "application/vnd.taskchampion.snapshot")) .insert_header(("Content-Type", "application/vnd.taskchampion.snapshot"))
@ -119,12 +119,12 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
// add a snapshot for a nonexistent version // add a snapshot for a nonexistent version
let uri = format!("/v1/client/add-snapshot/{}", version_id); let uri = format!("/v1/client/add-snapshot/{version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(("Content-Type", "application/vnd.taskchampion.snapshot")) .append_header(("Content-Type", "application/vnd.taskchampion.snapshot"))
@ -149,11 +149,11 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let version_id = Uuid::new_v4(); let version_id = Uuid::new_v4();
let storage = InMemoryStorage::new(); let storage = InMemoryStorage::new();
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-snapshot/{}", version_id); let uri = format!("/v1/client/add-snapshot/{version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(("Content-Type", "not/correct")) .append_header(("Content-Type", "not/correct"))
@ -169,11 +169,11 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let version_id = Uuid::new_v4(); let version_id = Uuid::new_v4();
let storage = InMemoryStorage::new(); let storage = InMemoryStorage::new();
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-snapshot/{}", version_id); let uri = format!("/v1/client/add-snapshot/{version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(( .append_header((

View file

@ -80,7 +80,7 @@ pub(crate) async fn service(
rb.append_header((PARENT_VERSION_ID_HEADER, parent_version_id.to_string())); rb.append_header((PARENT_VERSION_ID_HEADER, parent_version_id.to_string()));
Ok(rb.finish()) Ok(rb.finish())
} }
Err(ServerError::NoSuchClient) => { Err(ServerError::NoSuchClient) if server_state.web_config.create_clients => {
// Create a new client and repeat the `add_version` call. // Create a new client and repeat the `add_version` call.
let mut txn = server_state let mut txn = server_state
.server .server
@ -97,11 +97,11 @@ pub(crate) async fn service(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::api::CLIENT_ID_HEADER;
use crate::WebServer; use crate::WebServer;
use crate::{api::CLIENT_ID_HEADER, WebConfig};
use actix_web::{http::StatusCode, test, App}; use actix_web::{http::StatusCode, test, App};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use taskchampion_sync_server_core::{InMemoryStorage, Storage}; use taskchampion_sync_server_core::{InMemoryStorage, ServerConfig, Storage};
use uuid::Uuid; use uuid::Uuid;
#[actix_rt::test] #[actix_rt::test]
@ -118,11 +118,11 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{}", parent_version_id); let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(( .append_header((
@ -152,11 +152,15 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let version_id = Uuid::new_v4(); let version_id = Uuid::new_v4();
let parent_version_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4();
let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let server = WebServer::new(
ServerConfig::default(),
WebConfig::default(),
InMemoryStorage::new(),
);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{}", parent_version_id); let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(( .append_header((
@ -190,6 +194,36 @@ mod test {
} }
} }
#[actix_rt::test]
async fn test_auto_add_client_disabled() {
let client_id = Uuid::new_v4();
let parent_version_id = Uuid::new_v4();
let server = WebServer::new(
ServerConfig::default(),
WebConfig {
create_clients: false,
..WebConfig::default()
},
InMemoryStorage::new(),
);
let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post()
.uri(&uri)
.append_header((
"Content-Type",
"application/vnd.taskchampion.history-segment",
))
.append_header((CLIENT_ID_HEADER, client_id.to_string()))
.set_payload(b"abcd".to_vec())
.to_request();
let resp = test::call_service(&app, req).await;
// Client is not added, and returns 404.
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
#[actix_rt::test] #[actix_rt::test]
async fn test_conflict() { async fn test_conflict() {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
@ -204,11 +238,11 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{}", parent_version_id); let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(( .append_header((
@ -232,11 +266,11 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let parent_version_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4();
let storage = InMemoryStorage::new(); let storage = InMemoryStorage::new();
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{}", parent_version_id); let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(("Content-Type", "not/correct")) .append_header(("Content-Type", "not/correct"))
@ -252,11 +286,11 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let parent_version_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4();
let storage = InMemoryStorage::new(); let storage = InMemoryStorage::new();
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/add-version/{}", parent_version_id); let uri = format!("/v1/client/add-version/{parent_version_id}");
let req = test::TestRequest::post() let req = test::TestRequest::post()
.uri(&uri) .uri(&uri)
.append_header(( .append_header((

View file

@ -48,11 +48,11 @@ pub(crate) async fn service(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::api::CLIENT_ID_HEADER;
use crate::WebServer; use crate::WebServer;
use crate::{api::CLIENT_ID_HEADER, WebConfig};
use actix_web::{http::StatusCode, test, App}; use actix_web::{http::StatusCode, test, App};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use taskchampion_sync_server_core::{InMemoryStorage, Storage, NIL_VERSION_ID}; use taskchampion_sync_server_core::{InMemoryStorage, ServerConfig, Storage, NIL_VERSION_ID};
use uuid::Uuid; use uuid::Uuid;
#[actix_rt::test] #[actix_rt::test]
@ -71,11 +71,11 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/get-child-version/{}", parent_version_id); let uri = format!("/v1/client/get-child-version/{parent_version_id}");
let req = test::TestRequest::get() let req = test::TestRequest::get()
.uri(&uri) .uri(&uri)
.append_header((CLIENT_ID_HEADER, client_id.to_string())) .append_header((CLIENT_ID_HEADER, client_id.to_string()))
@ -105,11 +105,11 @@ mod test {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let parent_version_id = Uuid::new_v4(); let parent_version_id = Uuid::new_v4();
let storage = InMemoryStorage::new(); let storage = InMemoryStorage::new();
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
let uri = format!("/v1/client/get-child-version/{}", parent_version_id); let uri = format!("/v1/client/get-child-version/{parent_version_id}");
let req = test::TestRequest::get() let req = test::TestRequest::get()
.uri(&uri) .uri(&uri)
.append_header((CLIENT_ID_HEADER, client_id.to_string())) .append_header((CLIENT_ID_HEADER, client_id.to_string()))
@ -134,12 +134,12 @@ mod test {
.unwrap(); .unwrap();
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
// the child of the nil version is the added version // the child of the nil version is the added version
let uri = format!("/v1/client/get-child-version/{}", NIL_VERSION_ID); let uri = format!("/v1/client/get-child-version/{NIL_VERSION_ID}");
let req = test::TestRequest::get() let req = test::TestRequest::get()
.uri(&uri) .uri(&uri)
.append_header((CLIENT_ID_HEADER, client_id.to_string())) .append_header((CLIENT_ID_HEADER, client_id.to_string()))
@ -168,7 +168,7 @@ mod test {
// The child of the latest version is NOT_FOUND. The tests in crate::server test more // The child of the latest version is NOT_FOUND. The tests in crate::server test more
// corner cases. // corner cases.
let uri = format!("/v1/client/get-child-version/{}", test_version_id); let uri = format!("/v1/client/get-child-version/{test_version_id}");
let req = test::TestRequest::get() let req = test::TestRequest::get()
.uri(&uri) .uri(&uri)
.append_header((CLIENT_ID_HEADER, client_id.to_string())) .append_header((CLIENT_ID_HEADER, client_id.to_string()))

View file

@ -33,12 +33,12 @@ pub(crate) async fn service(
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::api::CLIENT_ID_HEADER;
use crate::WebServer; use crate::WebServer;
use crate::{api::CLIENT_ID_HEADER, WebConfig};
use actix_web::{http::StatusCode, test, App}; use actix_web::{http::StatusCode, test, App};
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use taskchampion_sync_server_core::{InMemoryStorage, Snapshot, Storage}; use taskchampion_sync_server_core::{InMemoryStorage, ServerConfig, Snapshot, Storage};
use uuid::Uuid; use uuid::Uuid;
#[actix_rt::test] #[actix_rt::test]
@ -53,7 +53,7 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;
@ -89,7 +89,7 @@ mod test {
txn.commit().unwrap(); txn.commit().unwrap();
} }
let server = WebServer::new(Default::default(), None, storage); let server = WebServer::new(ServerConfig::default(), WebConfig::default(), storage);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;

View file

@ -1,8 +1,7 @@
use std::collections::HashSet;
use actix_web::{error, web, HttpRequest, Result, Scope}; use actix_web::{error, web, HttpRequest, Result, Scope};
use taskchampion_sync_server_core::{ClientId, Server, ServerError}; use taskchampion_sync_server_core::{ClientId, Server, ServerError};
use uuid::Uuid;
use crate::WebConfig;
mod add_snapshot; mod add_snapshot;
mod add_version; mod add_version;
@ -31,7 +30,7 @@ pub(crate) const SNAPSHOT_REQUEST_HEADER: &str = "X-Snapshot-Request";
/// The type containing a reference to the persistent state for the server /// The type containing a reference to the persistent state for the server
pub(crate) struct ServerState { pub(crate) struct ServerState {
pub(crate) server: Server, pub(crate) server: Server,
pub(crate) client_id_allowlist: Option<HashSet<Uuid>>, pub(crate) web_config: WebConfig,
} }
impl ServerState { impl ServerState {
@ -43,7 +42,7 @@ impl ServerState {
if let Some(client_id_hdr) = req.headers().get(CLIENT_ID_HEADER) { if let Some(client_id_hdr) = req.headers().get(CLIENT_ID_HEADER) {
let client_id = client_id_hdr.to_str().map_err(|_| badrequest())?; let client_id = client_id_hdr.to_str().map_err(|_| badrequest())?;
let client_id = ClientId::parse_str(client_id).map_err(|_| badrequest())?; let client_id = ClientId::parse_str(client_id).map_err(|_| badrequest())?;
if let Some(allow_list) = &self.client_id_allowlist { if let Some(allow_list) = &self.web_config.client_id_allowlist {
if !allow_list.contains(&client_id) { if !allow_list.contains(&client_id) {
return Err(error::ErrorForbidden("unknown x-client-id")); return Err(error::ErrorForbidden("unknown x-client-id"));
} }
@ -80,13 +79,17 @@ fn server_error_to_actix(err: ServerError) -> actix_web::Error {
mod test { mod test {
use super::*; use super::*;
use taskchampion_sync_server_core::InMemoryStorage; use taskchampion_sync_server_core::InMemoryStorage;
use uuid::Uuid;
#[test] #[test]
fn client_id_header_allow_all() { fn client_id_header_allow_all() {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
let state = ServerState { let state = ServerState {
server: Server::new(Default::default(), InMemoryStorage::new()), server: Server::new(Default::default(), InMemoryStorage::new()),
client_id_allowlist: None, web_config: WebConfig {
client_id_allowlist: None,
create_clients: true,
},
}; };
let req = actix_web::test::TestRequest::default() let req = actix_web::test::TestRequest::default()
.insert_header((CLIENT_ID_HEADER, client_id.to_string())) .insert_header((CLIENT_ID_HEADER, client_id.to_string()))
@ -100,7 +103,10 @@ mod test {
let client_id_disallowed = Uuid::new_v4(); let client_id_disallowed = Uuid::new_v4();
let state = ServerState { let state = ServerState {
server: Server::new(Default::default(), InMemoryStorage::new()), server: Server::new(Default::default(), InMemoryStorage::new()),
client_id_allowlist: Some([client_id_ok].into()), web_config: WebConfig {
client_id_allowlist: Some([client_id_ok].into()),
create_clients: true,
},
}; };
let req = actix_web::test::TestRequest::default() let req = actix_web::test::TestRequest::default()
.insert_header((CLIENT_ID_HEADER, client_id_ok.to_string())) .insert_header((CLIENT_ID_HEADER, client_id_ok.to_string()))

View file

@ -8,7 +8,7 @@ use actix_web::{
}; };
use clap::{arg, builder::ValueParser, value_parser, ArgAction, Command}; use clap::{arg, builder::ValueParser, value_parser, ArgAction, Command};
use std::{collections::HashSet, ffi::OsString}; use std::{collections::HashSet, ffi::OsString};
use taskchampion_sync_server::WebServer; use taskchampion_sync_server::{WebConfig, WebServer};
use taskchampion_sync_server_core::ServerConfig; use taskchampion_sync_server_core::ServerConfig;
use taskchampion_sync_server_storage_sqlite::SqliteStorage; use taskchampion_sync_server_storage_sqlite::SqliteStorage;
use uuid::Uuid; use uuid::Uuid;
@ -43,6 +43,13 @@ fn command() -> Command {
.action(ArgAction::Append) .action(ArgAction::Append)
.required(false), .required(false),
) )
.arg(
arg!("create-clients": --"no-create-clients" "If a client does not exist in the database, do not create it")
.env("CREATE_CLIENTS")
.default_value("true")
.action(ArgAction::SetFalse)
.required(false),
)
.arg( .arg(
arg!(--"snapshot-versions" <NUM> "Target number of versions between snapshots") arg!(--"snapshot-versions" <NUM> "Target number of versions between snapshots")
.value_parser(value_parser!(u32)) .value_parser(value_parser!(u32))
@ -69,6 +76,7 @@ struct ServerArgs {
snapshot_versions: u32, snapshot_versions: u32,
snapshot_days: i64, snapshot_days: i64,
client_id_allowlist: Option<HashSet<Uuid>>, client_id_allowlist: Option<HashSet<Uuid>>,
create_clients: bool,
listen_addresses: Vec<String>, listen_addresses: Vec<String>,
} }
@ -81,6 +89,7 @@ impl ServerArgs {
client_id_allowlist: matches client_id_allowlist: matches
.get_many("allow-client-id") .get_many("allow-client-id")
.map(|ids| ids.copied().collect()), .map(|ids| ids.copied().collect()),
create_clients: matches.get_one("create-clients").copied().unwrap_or(true),
listen_addresses: matches listen_addresses: matches
.get_many::<String>("listen") .get_many::<String>("listen")
.unwrap() .unwrap()
@ -102,7 +111,10 @@ async fn main() -> anyhow::Result<()> {
}; };
let server = WebServer::new( let server = WebServer::new(
config, config,
server_args.client_id_allowlist, WebConfig {
client_id_allowlist: server_args.client_id_allowlist,
create_clients: server_args.create_clients,
},
SqliteStorage::new(server_args.data_dir)?, SqliteStorage::new(server_args.data_dir)?,
); );
@ -122,6 +134,8 @@ async fn main() -> anyhow::Result<()> {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
#![allow(clippy::bool_assert_comparison)]
use super::*; use super::*;
use actix_web::{self, App}; use actix_web::{self, App};
use clap::ArgMatches; use clap::ArgMatches;
@ -309,9 +323,54 @@ mod test {
); );
} }
#[test]
fn command_create_clients_default() {
with_var_unset("CREATE_CLIENTS", || {
let matches = command().get_matches_from(["tss", "--listen", "localhost:8080"]);
let server_args = ServerArgs::new(matches);
assert_eq!(server_args.create_clients, true);
});
}
#[test]
fn command_create_clients_cmdline() {
with_var_unset("CREATE_CLIENTS", || {
let matches = command().get_matches_from([
"tss",
"--listen",
"localhost:8080",
"--no-create-clients",
]);
let server_args = ServerArgs::new(matches);
assert_eq!(server_args.create_clients, false);
});
}
#[test]
fn command_create_clients_env_true() {
with_vars([("CREATE_CLIENTS", Some("true"))], || {
let matches = command().get_matches_from(["tss", "--listen", "localhost:8080"]);
let server_args = ServerArgs::new(matches);
assert_eq!(server_args.create_clients, true);
});
}
#[test]
fn command_create_clients_env_false() {
with_vars([("CREATE_CLIENTS", Some("false"))], || {
let matches = command().get_matches_from(["tss", "--listen", "localhost:8080"]);
let server_args = ServerArgs::new(matches);
assert_eq!(server_args.create_clients, false);
});
}
#[actix_rt::test] #[actix_rt::test]
async fn test_index_get() { async fn test_index_get() {
let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let server = WebServer::new(
ServerConfig::default(),
WebConfig::default(),
InMemoryStorage::new(),
);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = actix_web::test::init_service(app).await; let app = actix_web::test::init_service(app).await;

View file

@ -19,17 +19,32 @@ pub struct WebServer {
server_state: Arc<ServerState>, server_state: Arc<ServerState>,
} }
/// Configuration for WebServer (as distinct from [`ServerConfig`]).
pub struct WebConfig {
pub client_id_allowlist: Option<HashSet<Uuid>>,
pub create_clients: bool,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
client_id_allowlist: Default::default(),
create_clients: true,
}
}
}
impl WebServer { impl WebServer {
/// Create a new sync server with the given storage implementation. /// Create a new sync server with the given storage implementation.
pub fn new<ST: Storage + 'static>( pub fn new<ST: Storage + 'static>(
config: ServerConfig, config: ServerConfig,
client_id_allowlist: Option<HashSet<Uuid>>, web_config: WebConfig,
storage: ST, storage: ST,
) -> Self { ) -> Self {
Self { Self {
server_state: Arc::new(ServerState { server_state: Arc::new(ServerState {
server: Server::new(config, storage), server: Server::new(config, storage),
client_id_allowlist, web_config,
}), }),
} }
} }
@ -57,7 +72,11 @@ mod test {
#[actix_rt::test] #[actix_rt::test]
async fn test_cache_control() { async fn test_cache_control() {
let server = WebServer::new(Default::default(), None, InMemoryStorage::new()); let server = WebServer::new(
ServerConfig::default(),
WebConfig::default(),
InMemoryStorage::new(),
);
let app = App::new().configure(|sc| server.config(sc)); let app = App::new().configure(|sc| server.config(sc));
let app = test::init_service(app).await; let app = test::init_service(app).await;