From dadc9473d3b1d419a5aa531b775b46b834af5fd3 Mon Sep 17 00:00:00 2001
From: "Dustin J. Mitchell"
Date: Sun, 6 Feb 2022 16:21:42 +0000
Subject: [PATCH] unit tests for TCString

---
 Cargo.lock        |   1 +
 lib/Cargo.toml    |   3 +
 lib/src/string.rs | 181 ++++++++++++++++++++++++++++++++++++++--------
 3 files changed, 155 insertions(+), 30 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index ad2920754..ca9160c27 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3038,6 +3038,7 @@ dependencies = [
  "cbindgen",
  "chrono",
  "libc",
+ "pretty_assertions",
  "taskchampion",
  "uuid",
 ]
diff --git a/lib/Cargo.toml b/lib/Cargo.toml
index 66b8c6051..86a2ec735 100644
--- a/lib/Cargo.toml
+++ b/lib/Cargo.toml
@@ -15,5 +15,8 @@ taskchampion = { path = "../taskchampion" }
 uuid = { version = "^0.8.2", features = ["v4"] }
 anyhow = "1.0"
 
+[dev-dependencies]
+pretty_assertions = "1"
+
 [build-dependencies]
 cbindgen = "0.20.0"
diff --git a/lib/src/string.rs b/lib/src/string.rs
index d3f952189..08eed500c 100644
--- a/lib/src/string.rs
+++ b/lib/src/string.rs
@@ -34,6 +34,7 @@ use std::str::Utf8Error;
 /// must not use or free TCStrings after passing them to such API functions.
 ///
 /// TCStrings are not threadsafe.
+#[derive(PartialEq, Debug)]
 pub enum TCString<'a> {
     CString(CString),
     CStr(&'a CStr),
@@ -78,6 +79,32 @@ impl<'a> TCString<'a> {
         }
     }
 
+    /// Convert the TCString, in place, into one of the C variants. If this is not
+    /// possible, such as if the string contains an embedded NUL, then the string
+    /// remains unchanged.
+    fn to_c_string(&mut self) {
+        if matches!(self, TCString::String(_)) {
+            // we must take ownership of the String in order to try converting it,
+            // leaving the underlying TCString as its default (None)
+            if let TCString::String(string) = std::mem::take(self) {
+                match CString::new(string) {
+                    Ok(cstring) => *self = TCString::CString(cstring),
+                    Err(nul_err) => {
+                        // recover the underlying String from the NulError and restore
+                        // the TCString
+                        let original_bytes = nul_err.into_vec();
+                        // SAFETY: original_bytes came from a String moments ago, so still valid utf8
+                        let string = unsafe { String::from_utf8_unchecked(original_bytes) };
+                        *self = TCString::String(string);
+                    }
+                }
+            } else {
+                // the `matches!` above verified self was a TCString::String
+                unreachable!()
+            }
+        }
+    }
+
     pub(crate) fn to_path_buf(&self) -> PathBuf {
         // TODO: this is UNIX-specific.
         let path: &OsStr = OsStr::from_bytes(self.as_bytes());
@@ -158,7 +185,10 @@ pub extern "C" fn tc_string_clone_with_len(
     // does not outlive this function call)
     // - the length of the buffer is less than isize::MAX (promised by caller)
     let slice = unsafe { std::slice::from_raw_parts(buf as *const u8, len) };
+
+    // allocate and copy into Rust-controlled memory
     let vec = slice.to_vec();
+
     // try converting to a string, which is the only variant that can contain embedded NULs. If
     // the bytes are not valid utf-8, store that information for reporting later.
     let tcstring = match String::from_utf8(vec) {
@@ -168,6 +198,7 @@ pub extern "C" fn tc_string_clone_with_len(
             TCString::InvalidUtf8(e, vec)
         }
     };
+
     // SAFETY: see docstring
     unsafe { tcstring.return_val() }
 }
@@ -190,32 +221,11 @@ pub extern "C" fn tc_string_content(tcstring: *mut TCString) -> *const libc::c_c
 
     // if we have a String, we need to consume it and turn it into
     // a CString.
-    if matches!(tcstring, TCString::String(_)) {
-        // TODO: put this in a method
-        if let TCString::String(string) = std::mem::take(tcstring) {
-            match CString::new(string) {
-                Ok(cstring) => {
-                    *tcstring = TCString::CString(cstring);
-                }
-                Err(nul_err) => {
-                    // recover the underlying String from the NulError
-                    let original_bytes = nul_err.into_vec();
-                    // SAFETY: original_bytes came from a String moments ago, so still valid utf8
-                    let string = unsafe { String::from_utf8_unchecked(original_bytes) };
-                    *tcstring = TCString::String(string);
-
-                    // and return NULL as advertized
-                    return std::ptr::null();
-                }
-            }
-        } else {
-            unreachable!()
-        }
-    }
+    tcstring.to_c_string();
 
     match tcstring {
         TCString::CString(cstring) => cstring.as_ptr(),
-        TCString::String(_) => unreachable!(), // just converted to CString
+        TCString::String(_) => std::ptr::null(), // to_c_string failed
         TCString::CStr(cstr) => cstr.as_ptr(),
         TCString::InvalidUtf8(_, _) => std::ptr::null(),
         TCString::None => unreachable!(),
@@ -240,13 +250,8 @@ pub extern "C" fn tc_string_content_with_len(
     let tcstring = unsafe { TCString::from_arg_ref(tcstring) };
     debug_assert!(!len_out.is_null());
 
-    let bytes = match tcstring {
-        TCString::CString(cstring) => cstring.as_bytes(),
-        TCString::String(string) => string.as_bytes(),
-        TCString::CStr(cstr) => cstr.to_bytes(),
-        TCString::InvalidUtf8(_, ref v) => v.as_ref(),
-        TCString::None => unreachable!(),
-    };
+    let bytes = tcstring.as_bytes();
+
     // SAFETY:
     // - len_out is not NULL (checked by assertion, promised by caller)
     // - len_out points to valid memory (promised by caller)
@@ -264,3 +269,119 @@ pub extern "C" fn tc_string_free(tcstring: *mut TCString) {
     // - caller is exclusive owner of tcstring (promised by caller)
     drop(unsafe { TCString::take_from_arg(tcstring) });
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use pretty_assertions::assert_eq;
+
+    const INVALID_UTF8: &[u8] = b"abc\xf0\x28\x8c\x28";
+
+    fn make_cstring() -> TCString<'static> {
+        TCString::CString(CString::new("a string").unwrap())
+    }
+
+    fn make_cstr() -> TCString<'static> {
+        let cstr = CStr::from_bytes_with_nul(b"a string\0").unwrap();
+        TCString::CStr(&cstr)
+    }
+
+    fn make_string() -> TCString<'static> {
+        TCString::String("a string".into())
+    }
+
+    fn make_string_with_nul() -> TCString<'static> {
+        TCString::String("a \0 nul!".into())
+    }
+
+    fn make_invalid() -> TCString<'static> {
+        let e = String::from_utf8(INVALID_UTF8.to_vec()).unwrap_err();
+        TCString::InvalidUtf8(e.utf8_error(), e.into_bytes())
+    }
+
+    #[test]
+    fn cstring_as_str() {
+        assert_eq!(make_cstring().as_str().unwrap(), "a string");
+    }
+
+    #[test]
+    fn cstr_as_str() {
+        assert_eq!(make_cstr().as_str().unwrap(), "a string");
+    }
+
+    #[test]
+    fn string_as_str() {
+        assert_eq!(make_string().as_str().unwrap(), "a string");
+    }
+
+    #[test]
+    fn string_with_nul_as_str() {
+        assert_eq!(make_string_with_nul().as_str().unwrap(), "a \0 nul!");
+    }
+
+    #[test]
+    fn invalid_as_str() {
+        let as_str_err = make_invalid().as_str().unwrap_err();
+        assert_eq!(as_str_err.valid_up_to(), 3); // "abc" is valid
+    }
+
+    #[test]
+    fn cstring_as_bytes() {
+        assert_eq!(make_cstring().as_bytes(), b"a string");
+    }
+
+    #[test]
+    fn cstr_as_bytes() {
+        assert_eq!(make_cstr().as_bytes(), b"a string");
+    }
+
+    #[test]
+    fn string_as_bytes() {
+        assert_eq!(make_string().as_bytes(), b"a string");
+    }
+
+    #[test]
+    fn string_with_nul_as_bytes() {
+        assert_eq!(make_string_with_nul().as_bytes(), b"a \0 nul!");
+    }
+
+    #[test]
+    fn invalid_as_bytes() {
+        assert_eq!(make_invalid().as_bytes(), INVALID_UTF8);
+    }
+
+    #[test]
+    fn cstring_to_c_string() {
+        let mut tcstring = make_cstring();
+        tcstring.to_c_string();
+        assert_eq!(tcstring, make_cstring()); // unchanged
+    }
+
+    #[test]
+    fn cstr_to_c_string() {
+        let mut tcstring = make_cstr();
+        tcstring.to_c_string();
+        assert_eq!(tcstring, make_cstr()); // unchanged
+    }
+
+    #[test]
+    fn string_to_c_string() {
+        let mut tcstring = make_string();
+        tcstring.to_c_string();
+        assert_eq!(tcstring, make_cstring()); // converted to CString, same content
+    }
+
+    #[test]
+    fn string_with_nul_to_c_string() {
+        let mut tcstring = make_string_with_nul();
+        tcstring.to_c_string();
+        assert_eq!(tcstring, make_string_with_nul()); // unchanged
+    }
+
+    #[test]
+    fn invalid_to_c_string() {
+        let mut tcstring = make_invalid();
+        tcstring.to_c_string();
+        assert_eq!(tcstring, make_invalid()); // unchanged
+    }
+}