diff --git a/core/src/inmemory.rs b/core/src/inmemory.rs index 23cea4a..9ed0f0b 100644 --- a/core/src/inmemory.rs +++ b/core/src/inmemory.rs @@ -21,7 +21,7 @@ struct Inner { /// /// This is not for production use, but supports testing of sync server implementations. /// -/// NOTE: this does not implement transaction rollback. +/// NOTE: this panics on transaction rollback, as it is just for testing. pub struct InMemoryStorage(Mutex); impl InMemoryStorage { @@ -36,30 +36,39 @@ impl InMemoryStorage { } } -struct InnerTxn<'a>(MutexGuard<'a, Inner>); +struct InnerTxn<'a> { + guard: MutexGuard<'a, Inner>, + written: bool, + committed: bool, +} impl Storage for InMemoryStorage { - fn txn<'a>(&'a self) -> anyhow::Result> { - Ok(Box::new(InnerTxn(self.0.lock().expect("poisoned lock")))) + fn txn(&self) -> anyhow::Result> { + Ok(Box::new(InnerTxn { + guard: self.0.lock().expect("poisoned lock"), + written: false, + committed: false, + })) } } impl<'a> StorageTxn for InnerTxn<'a> { fn get_client(&mut self, client_id: Uuid) -> anyhow::Result> { - Ok(self.0.clients.get(&client_id).cloned()) + Ok(self.guard.clients.get(&client_id).cloned()) } fn new_client(&mut self, client_id: Uuid, latest_version_id: Uuid) -> anyhow::Result<()> { - if self.0.clients.contains_key(&client_id) { + if self.guard.clients.contains_key(&client_id) { return Err(anyhow::anyhow!("Client {} already exists", client_id)); } - self.0.clients.insert( + self.guard.clients.insert( client_id, Client { latest_version_id, snapshot: None, }, ); + self.written = true; Ok(()) } @@ -70,12 +79,13 @@ impl<'a> StorageTxn for InnerTxn<'a> { data: Vec, ) -> anyhow::Result<()> { let client = self - .0 + .guard .clients .get_mut(&client_id) .ok_or_else(|| anyhow::anyhow!("no such client"))?; client.snapshot = Some(snapshot); - self.0.snapshots.insert(client_id, data); + self.guard.snapshots.insert(client_id, data); + self.written = true; Ok(()) } @@ -85,12 +95,12 @@ impl<'a> StorageTxn for InnerTxn<'a> { version_id: Uuid, ) -> anyhow::Result>> { // sanity check - let client = self.0.clients.get(&client_id); + let client = self.guard.clients.get(&client_id); let client = client.ok_or_else(|| anyhow::anyhow!("no such client"))?; if Some(&version_id) != client.snapshot.as_ref().map(|snap| &snap.version_id) { return Err(anyhow::anyhow!("unexpected snapshot_version_id")); } - Ok(self.0.snapshots.get(&client_id).cloned()) + Ok(self.guard.snapshots.get(&client_id).cloned()) } fn get_version_by_parent( @@ -98,9 +108,9 @@ impl<'a> StorageTxn for InnerTxn<'a> { client_id: Uuid, parent_version_id: Uuid, ) -> anyhow::Result> { - if let Some(parent_version_id) = self.0.children.get(&(client_id, parent_version_id)) { + if let Some(parent_version_id) = self.guard.children.get(&(client_id, parent_version_id)) { Ok(self - .0 + .guard .versions .get(&(client_id, *parent_version_id)) .cloned()) @@ -114,7 +124,7 @@ impl<'a> StorageTxn for InnerTxn<'a> { client_id: Uuid, version_id: Uuid, ) -> anyhow::Result> { - Ok(self.0.versions.get(&(client_id, version_id)).cloned()) + Ok(self.guard.versions.get(&(client_id, version_id)).cloned()) } fn add_version( @@ -131,7 +141,7 @@ impl<'a> StorageTxn for InnerTxn<'a> { history_segment, }; - if let Some(client) = self.0.clients.get_mut(&client_id) { + if let Some(client) = self.guard.clients.get_mut(&client_id) { client.latest_version_id = version_id; if let Some(ref mut snap) = client.snapshot { snap.versions_since += 1; @@ -140,19 +150,29 @@ impl<'a> StorageTxn for InnerTxn<'a> { return Err(anyhow::anyhow!("Client {} does not exist", client_id)); } - self.0 + self.guard .children .insert((client_id, parent_version_id), version_id); - self.0.versions.insert((client_id, version_id), version); + self.guard.versions.insert((client_id, version_id), version); + self.written = true; Ok(()) } fn commit(&mut self) -> anyhow::Result<()> { + self.committed = true; Ok(()) } } +impl<'a> Drop for InnerTxn<'a> { + fn drop(&mut self) { + if self.written && !self.committed { + panic!("Uncommitted InMemoryStorage transaction dropped without commiting"); + } + } +} + #[cfg(test)] mod test { use super::*; @@ -198,6 +218,7 @@ mod test { assert_eq!(client.latest_version_id, latest_version_id); assert_eq!(client.snapshot.unwrap(), snap); + txn.commit()?; Ok(()) } @@ -242,6 +263,7 @@ mod test { let version = txn.get_version(client_id, version_id)?.unwrap(); assert_eq!(version, expected); + txn.commit()?; Ok(()) } @@ -284,6 +306,7 @@ mod test { // check that mismatched version is detected assert!(txn.get_snapshot_data(client_id, Uuid::new_v4()).is_err()); + txn.commit()?; Ok(()) } } diff --git a/core/src/server.rs b/core/src/server.rs index ea871a5..c22fa31 100644 --- a/core/src/server.rs +++ b/core/src/server.rs @@ -307,7 +307,12 @@ mod test { { let _ = env_logger::builder().is_test(true).try_init(); let storage = InMemoryStorage::new(); - let res = init(storage.txn()?.as_mut())?; + let res; + { + let mut txn = storage.txn()?; + res = init(txn.as_mut())?; + txn.commit()?; + } Ok((Server::new(ServerConfig::default(), storage), res)) } diff --git a/core/src/storage.rs b/core/src/storage.rs index 7845dd6..e0c9621 100644 --- a/core/src/storage.rs +++ b/core/src/storage.rs @@ -36,8 +36,11 @@ pub struct Version { /// A transaction in the storage backend. /// /// Transactions must be sequentially consistent. That is, the results of transactions performed -/// in storage must be as if each were executed sequentially in some order. In particular, the -/// `Client.latest_version` must not change between a call to `get_client` and `add_version`. +/// in storage must be as if each were executed sequentially in some order. In particular, +/// un-committed changes must not be read by another transaction. +/// +/// Changes in a transaction that is dropped without calling `commit` must not appear in any other +/// transaction. pub trait StorageTxn { /// Get information about the given client fn get_client(&mut self, client_id: Uuid) -> anyhow::Result>; @@ -92,5 +95,5 @@ pub trait StorageTxn { /// [`crate::storage::StorageTxn`] trait. pub trait Storage: Send + Sync { /// Begin a transaction - fn txn<'a>(&'a self) -> anyhow::Result>; + fn txn(&self) -> anyhow::Result>; } diff --git a/server/src/api/add_snapshot.rs b/server/src/api/add_snapshot.rs index 783bd6c..1769ca7 100644 --- a/server/src/api/add_snapshot.rs +++ b/server/src/api/add_snapshot.rs @@ -73,6 +73,7 @@ mod test { let mut txn = storage.txn().unwrap(); txn.new_client(client_id, version_id).unwrap(); txn.add_version(client_id, version_id, NIL_VERSION_ID, vec![])?; + txn.commit()?; } let server = WebServer::new(Default::default(), None, storage); @@ -115,6 +116,7 @@ mod test { { let mut txn = storage.txn().unwrap(); txn.new_client(client_id, NIL_VERSION_ID).unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); diff --git a/server/src/api/add_version.rs b/server/src/api/add_version.rs index 12d80e3..d6ac245 100644 --- a/server/src/api/add_version.rs +++ b/server/src/api/add_version.rs @@ -113,6 +113,7 @@ mod test { { let mut txn = storage.txn().unwrap(); txn.new_client(client_id, Uuid::nil()).unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); @@ -198,6 +199,7 @@ mod test { { let mut txn = storage.txn().unwrap(); txn.new_client(client_id, version_id).unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); diff --git a/server/src/api/get_child_version.rs b/server/src/api/get_child_version.rs index 0779415..4cf53ef 100644 --- a/server/src/api/get_child_version.rs +++ b/server/src/api/get_child_version.rs @@ -68,6 +68,7 @@ mod test { txn.new_client(client_id, Uuid::new_v4()).unwrap(); txn.add_version(client_id, version_id, parent_version_id, b"abcd".to_vec()) .unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); @@ -131,6 +132,7 @@ mod test { txn.new_client(client_id, Uuid::new_v4()).unwrap(); txn.add_version(client_id, test_version_id, NIL_VERSION_ID, b"vers".to_vec()) .unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); let app = App::new().configure(|sc| server.config(sc)); diff --git a/server/src/api/get_snapshot.rs b/server/src/api/get_snapshot.rs index 66b8a77..6eff71a 100644 --- a/server/src/api/get_snapshot.rs +++ b/server/src/api/get_snapshot.rs @@ -50,6 +50,7 @@ mod test { { let mut txn = storage.txn().unwrap(); txn.new_client(client_id, Uuid::new_v4()).unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); @@ -86,6 +87,7 @@ mod test { snapshot_data.clone(), ) .unwrap(); + txn.commit().unwrap(); } let server = WebServer::new(Default::default(), None, storage); diff --git a/sqlite/src/lib.rs b/sqlite/src/lib.rs index edd0be1..e64c0c5 100644 --- a/sqlite/src/lib.rs +++ b/sqlite/src/lib.rs @@ -7,12 +7,6 @@ use std::path::Path; use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version}; use uuid::Uuid; -#[derive(Debug, thiserror::Error)] -enum SqliteError { - #[error("Failed to create SQLite transaction")] - CreateTransactionFailed, -} - /// Newtype to allow implementing `FromSql` for foreign `uuid::Uuid` struct StoredUuid(Uuid); @@ -34,6 +28,9 @@ impl ToSql for StoredUuid { } /// An on-disk storage backend which uses SQLite. +/// +/// A new connection is opened for each transaction, and only one transaction may be active at a +/// time; a second call to `txn` will block until the first transaction is dropped. pub struct SqliteStorage { db_file: std::path::PathBuf, } @@ -54,11 +51,13 @@ impl SqliteStorage { let o = SqliteStorage { db_file }; - { - let mut con = o.new_connection()?; - let txn = con.transaction()?; + let con = o.new_connection()?; - let queries = vec![ + // Use the modern WAL mode. + con.query_row("PRAGMA journal_mode=WAL", [], |_row| Ok(())) + .context("Setting journal_mode=WAL")?; + + let queries = vec![ "CREATE TABLE IF NOT EXISTS clients ( client_id STRING PRIMARY KEY, latest_version_id STRING, @@ -69,11 +68,9 @@ impl SqliteStorage { "CREATE TABLE IF NOT EXISTS versions (version_id STRING PRIMARY KEY, client_id STRING, parent_version_id STRING, history_segment BLOB);", "CREATE INDEX IF NOT EXISTS versions_by_parent ON versions (parent_version_id);", ]; - for q in queries { - txn.execute(q, []) - .context("Error while creating SQLite tables")?; - } - txn.commit()?; + for q in queries { + con.execute(q, []) + .context("Error while creating SQLite tables")?; } Ok(o) @@ -83,22 +80,22 @@ impl SqliteStorage { impl Storage for SqliteStorage { fn txn<'a>(&'a self) -> anyhow::Result> { let con = self.new_connection()?; - let t = Txn { con }; - Ok(Box::new(t)) + // Begin the transaction on this new connection. An IMMEDIATE connection is in + // write (exclusive) mode from the start. + con.execute("BEGIN IMMEDIATE", [])?; + let txn = Txn { con }; + Ok(Box::new(txn)) } } struct Txn { + // SQLite only allows one concurrent transaction per connection, and rusqlite emulates + // transactions by running `BEGIN ...` and `COMMIT` at appropriate times. So we will do + // the same. con: Connection, } impl Txn { - fn get_txn(&mut self) -> Result { - self.con - .transaction() - .map_err(|_e| SqliteError::CreateTransactionFailed) - } - /// Implementation for queries from the versions table fn get_version_impl( &mut self, @@ -106,8 +103,8 @@ impl Txn { client_id: Uuid, version_id_arg: Uuid, ) -> anyhow::Result> { - let t = self.get_txn()?; - let r = t + let r = self + .con .query_row( query, params![&StoredUuid(version_id_arg), &StoredUuid(client_id)], @@ -130,8 +127,8 @@ impl Txn { impl StorageTxn for Txn { fn get_client(&mut self, client_id: Uuid) -> anyhow::Result> { - let t = self.get_txn()?; - let result: Option = t + let result: Option = self + .con .query_row( "SELECT latest_version_id, @@ -174,14 +171,12 @@ impl StorageTxn for Txn { } fn new_client(&mut self, client_id: Uuid, latest_version_id: Uuid) -> anyhow::Result<()> { - let t = self.get_txn()?; - - t.execute( - "INSERT OR REPLACE INTO clients (client_id, latest_version_id) VALUES (?, ?)", - params![&StoredUuid(client_id), &StoredUuid(latest_version_id)], - ) - .context("Error creating/updating client")?; - t.commit()?; + self.con + .execute( + "INSERT OR REPLACE INTO clients (client_id, latest_version_id) VALUES (?, ?)", + params![&StoredUuid(client_id), &StoredUuid(latest_version_id)], + ) + .context("Error creating/updating client")?; Ok(()) } @@ -191,26 +186,24 @@ impl StorageTxn for Txn { snapshot: Snapshot, data: Vec, ) -> anyhow::Result<()> { - let t = self.get_txn()?; - - t.execute( - "UPDATE clients + self.con + .execute( + "UPDATE clients SET snapshot_version_id = ?, snapshot_timestamp = ?, versions_since_snapshot = ?, snapshot = ? WHERE client_id = ?", - params![ - &StoredUuid(snapshot.version_id), - snapshot.timestamp.timestamp(), - snapshot.versions_since, - data, - &StoredUuid(client_id), - ], - ) - .context("Error creating/updating snapshot")?; - t.commit()?; + params![ + &StoredUuid(snapshot.version_id), + snapshot.timestamp.timestamp(), + snapshot.versions_since, + data, + &StoredUuid(client_id), + ], + ) + .context("Error creating/updating snapshot")?; Ok(()) } @@ -219,8 +212,8 @@ impl StorageTxn for Txn { client_id: Uuid, version_id: Uuid, ) -> anyhow::Result>> { - let t = self.get_txn()?; - let r = t + let r = self + .con .query_row( "SELECT snapshot, snapshot_version_id FROM clients WHERE client_id = ?", params![&StoredUuid(client_id)], @@ -271,9 +264,7 @@ impl StorageTxn for Txn { parent_version_id: Uuid, history_segment: Vec, ) -> anyhow::Result<()> { - let t = self.get_txn()?; - - t.execute( + self.con.execute( "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment) VALUES(?, ?, ?, ?)", params![ StoredUuid(version_id), @@ -283,25 +274,22 @@ impl StorageTxn for Txn { ] ) .context("Error adding version")?; - t.execute( - "UPDATE clients + self.con + .execute( + "UPDATE clients SET latest_version_id = ?, versions_since_snapshot = versions_since_snapshot + 1 WHERE client_id = ?", - params![StoredUuid(version_id), StoredUuid(client_id),], - ) - .context("Error updating client for new version")?; + params![StoredUuid(version_id), StoredUuid(client_id),], + ) + .context("Error updating client for new version")?; - t.commit()?; Ok(()) } fn commit(&mut self) -> anyhow::Result<()> { - // FIXME: Note the queries aren't currently run in a - // transaction, as storing the transaction object and a pooled - // connection in the `Txn` object is complex. - // https://github.com/taskchampion/taskchampion/pull/206#issuecomment-860336073 + self.con.execute("COMMIT", [])?; Ok(()) } } diff --git a/sqlite/tests/concurrency.rs b/sqlite/tests/concurrency.rs new file mode 100644 index 0000000..17c729e --- /dev/null +++ b/sqlite/tests/concurrency.rs @@ -0,0 +1,67 @@ +use std::thread; +use taskchampion_sync_server_core::{Storage, NIL_VERSION_ID}; +use taskchampion_sync_server_storage_sqlite::SqliteStorage; +use tempfile::TempDir; +use uuid::Uuid; + +#[test] +fn add_version_concurrency() -> anyhow::Result<()> { + let tmp_dir = TempDir::new()?; + let client_id = Uuid::new_v4(); + + { + let con = SqliteStorage::new(tmp_dir.path())?; + let mut txn = con.txn()?; + txn.new_client(client_id, NIL_VERSION_ID)?; + txn.commit()?; + } + + const N: i32 = 100; + const T: i32 = 4; + + // Add N versions to the DB. + let add_versions = || { + let con = SqliteStorage::new(tmp_dir.path())?; + + for _ in 0..N { + let mut txn = con.txn()?; + let client = txn.get_client(client_id)?.unwrap(); + let version_id = Uuid::new_v4(); + let parent_version_id = client.latest_version_id; + std::thread::yield_now(); // Make failure more likely. + txn.add_version(client_id, version_id, parent_version_id, b"data".to_vec())?; + txn.commit()?; + } + + Ok::<_, anyhow::Error>(()) + }; + + thread::scope(|s| { + // Spawn T threads. + for _ in 0..T { + s.spawn(add_versions); + } + }); + + // There should now be precisely N*T versions. This number will be smaller if there was a + // concurrency error allowing two conflicting `add_version` calls to overlap. + { + let con = SqliteStorage::new(tmp_dir.path())?; + let mut txn = con.txn()?; + let client = txn.get_client(client_id)?.unwrap(); + + let mut n = 0; + let mut version_id = client.latest_version_id; + while version_id != NIL_VERSION_ID { + let version = txn + .get_version(client_id, version_id)? + .expect("version should exist"); + n += 1; + version_id = version.parent_version_id; + } + + assert_eq!(n, N * T); + } + + Ok(()) +}