diff --git a/crates/libs/windows/src/core/strings/literals.rs b/crates/libs/windows/src/core/strings/literals.rs index 147f6a8da5..56458cbe12 100644 --- a/crates/libs/windows/src/core/strings/literals.rs +++ b/crates/libs/windows/src/core/strings/literals.rs @@ -24,7 +24,7 @@ macro_rules! w { let mut buffer = [0; OUTPUT_LEN]; let mut input_pos = 0; let mut output_pos = 0; - while let Some((mut code_point, new_pos)) = ::windows::core::decode_utf8(INPUT, input_pos) { + while let Some((mut code_point, new_pos)) = ::windows::core::decode_utf8_char(INPUT, input_pos) { input_pos = new_pos; if code_point <= 0xffff { buffer[output_pos] = code_point as u16; @@ -53,7 +53,7 @@ pub use s; pub use w; #[doc(hidden)] -pub const fn decode_utf8(bytes: &[u8], mut pos: usize) -> Option<(u32, usize)> { +pub const fn decode_utf8_char(bytes: &[u8], mut pos: usize) -> Option<(u32, usize)> { if bytes.len() == pos { return None; } @@ -130,7 +130,7 @@ pub struct HSTRING_HEADER { pub const fn utf16_len(bytes: &[u8]) -> usize { let mut pos = 0; let mut len = 0; - while let Some((code_point, new_pos)) = decode_utf8(bytes, pos) { + while let Some((code_point, new_pos)) = decode_utf8_char(bytes, pos) { pos = new_pos; len += if code_point <= 0xffff { 1 } else { 2 }; } @@ -143,10 +143,10 @@ mod tests { #[test] fn test() { - assert_eq!(decode_utf8(b"123", 0), Some((0x31, 1))); - assert_eq!(decode_utf8(b"123", 1), Some((0x32, 2))); - assert_eq!(decode_utf8(b"123", 2), Some((0x33, 3))); - assert_eq!(decode_utf8(b"123", 3), None); + assert_eq!(decode_utf8_char(b"123", 0), Some((0x31, 1))); + assert_eq!(decode_utf8_char(b"123", 1), Some((0x32, 2))); + assert_eq!(decode_utf8_char(b"123", 2), Some((0x33, 3))); + assert_eq!(decode_utf8_char(b"123", 3), None); assert_eq!(utf16_len(b"123"), 3); assert_eq!(utf16_len("α & ω".as_bytes()), 5); } diff --git a/crates/libs/windows/src/core/strings/mod.rs b/crates/libs/windows/src/core/strings/mod.rs index 886234a135..df4a03a20c 100644 --- a/crates/libs/windows/src/core/strings/mod.rs +++ b/crates/libs/windows/src/core/strings/mod.rs @@ -19,3 +19,60 @@ extern "C" { pub fn strlen(s: PCSTR) -> usize; pub fn wcslen(s: PCWSTR) -> usize; } + +/// An internal helper for decoding an iterator of chars and displaying them +struct Decode(F); + +impl core::fmt::Display for Decode +where + F: Clone + FnOnce() -> R, + R: IntoIterator>, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use core::fmt::Write; + let iter = self.0.clone(); + for c in iter().into_iter() { + f.write_char(c.unwrap_or_else(|_| std::char::REPLACEMENT_CHARACTER))? + } + Ok(()) + } +} + +/// Mirror of `std::char::decode_utf16` for utf-8. +fn decode_utf8<'a>(mut buffer: &'a [u8]) -> impl Iterator> + 'a { + let mut current = "".chars(); + let mut previous_error = None; + std::iter::from_fn(move || { + loop { + match (current.next(), previous_error) { + (Some(c), _) => return Some(Ok(c)), + // Return the previous error + (None, Some(e)) => { + previous_error = None; + return Some(Err(e)); + } + // We're completely done + (None, None) if buffer.is_empty() => return None, + (None, None) => { + match std::str::from_utf8(buffer) { + Ok(s) => { + current = s.chars(); + buffer = &[]; + } + Err(e) => { + let (valid, rest) = buffer.split_at(e.valid_up_to()); + // Skip the invalid sequence and stop completely if we ended early + let invalid_sequence_length = e.error_len()?; + buffer = &rest[invalid_sequence_length..]; + + // Set the current iterator to the valid section and indicate previous error + // SAFETY: `valid` is known to be valid utf-8 from error + current = unsafe { std::str::from_utf8_unchecked(valid) }.chars(); + previous_error = Some(e); + } + } + } + } + } + }) +} diff --git a/crates/libs/windows/src/core/strings/pcstr.rs b/crates/libs/windows/src/core/strings/pcstr.rs index e24aa3948e..a908f04692 100644 --- a/crates/libs/windows/src/core/strings/pcstr.rs +++ b/crates/libs/windows/src/core/strings/pcstr.rs @@ -44,6 +44,15 @@ impl PCSTR { pub unsafe fn to_string(&self) -> core::result::Result { String::from_utf8(self.as_bytes().into()) } + + /// Allow this string to be displayed. + /// + /// # Safety + /// + /// See the safety information for `PCSTR::as_bytes`. + pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a { + Decode(move || decode_utf8(self.as_bytes())) + } } unsafe impl Abi for PCSTR { @@ -57,3 +66,16 @@ impl From> for PCSTR { from.unwrap_or_else(Self::null) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn can_display() { + // 💖 followed by an invalid byte sequence and then an incomplete one + let s = vec![240, 159, 146, 150, 255, 240, 159, 0]; + let s = PCSTR::from_raw(s.as_ptr()); + assert_eq!("💖�", format!("{}", unsafe { s.display() })); + } +} diff --git a/crates/libs/windows/src/core/strings/pcwstr.rs b/crates/libs/windows/src/core/strings/pcwstr.rs index 7ac51cef5c..b9c03a694c 100644 --- a/crates/libs/windows/src/core/strings/pcwstr.rs +++ b/crates/libs/windows/src/core/strings/pcwstr.rs @@ -44,6 +44,15 @@ impl PCWSTR { pub unsafe fn to_string(&self) -> core::result::Result { String::from_utf16(self.as_wide()) } + + /// Allow this string to be displayed. + /// + /// # Safety + /// + /// See the safety information for `PCWSTR::as_wide`. + pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a { + Decode(move || core::char::decode_utf16(self.as_wide().iter().cloned())) + } } unsafe impl Abi for PCWSTR { diff --git a/crates/libs/windows/src/core/strings/pstr.rs b/crates/libs/windows/src/core/strings/pstr.rs index 0d904f7042..9283022825 100644 --- a/crates/libs/windows/src/core/strings/pstr.rs +++ b/crates/libs/windows/src/core/strings/pstr.rs @@ -44,6 +44,15 @@ impl PSTR { pub unsafe fn to_string(&self) -> core::result::Result { String::from_utf8(self.as_bytes().into()) } + + /// Allow this string to be displayed. + /// + /// # Safety + /// + /// See the safety information for `PSTR::as_bytes`. + pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a { + Decode(move || decode_utf8(self.as_bytes())) + } } unsafe impl Abi for PSTR { diff --git a/crates/libs/windows/src/core/strings/pwstr.rs b/crates/libs/windows/src/core/strings/pwstr.rs index 90713bae28..e2a299a0b3 100644 --- a/crates/libs/windows/src/core/strings/pwstr.rs +++ b/crates/libs/windows/src/core/strings/pwstr.rs @@ -6,27 +6,27 @@ use super::*; pub struct PWSTR(pub *mut u16); impl PWSTR { - /// Construct a new `PWSTR` from a raw pointer + /// Construct a new `PWSTR` from a raw pointer. pub const fn from_raw(ptr: *mut u16) -> Self { Self(ptr) } - /// Construct a null `PWSTR` + /// Construct a null `PWSTR`. pub fn null() -> Self { Self(core::ptr::null_mut()) } - /// Returns a raw pointer to the `PWSTR` + /// Returns a raw pointer to the `PWSTR`. pub fn as_ptr(&self) -> *mut u16 { self.0 } - /// Checks whether the `PWSTR` is null + /// Checks whether the `PWSTR` is null. pub fn is_null(&self) -> bool { self.0.is_null() } - /// String data without the trailing 0 + /// String data without the trailing 0. /// /// # Safety /// @@ -40,10 +40,19 @@ impl PWSTR { /// /// # Safety /// - /// See the safety information for `PWSTR::as_bytes`. + /// See the safety information for `PWSTR::as_wide`. pub unsafe fn to_string(&self) -> core::result::Result { String::from_utf16(self.as_wide().into()) } + + /// Allow this string to be displayed. + /// + /// # Safety + /// + /// See the safety information for `PWSTR::as_wide`. + pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a { + Decode(move || core::char::decode_utf16(self.as_wide().iter().cloned())) + } } unsafe impl Abi for PWSTR { diff --git a/crates/samples/spellchecker/src/main.rs b/crates/samples/spellchecker/src/main.rs index c913c8f227..c1e4945046 100644 --- a/crates/samples/spellchecker/src/main.rs +++ b/crates/samples/spellchecker/src/main.rs @@ -43,9 +43,9 @@ fn main() -> Result<()> { // Get the replacement as a widestring and convert to a Rust String let replacement = unsafe { error.Replacement()? }; - println!("Replace: {} with {}", substring, unsafe { replacement.to_string().unwrap() }); + println!("Replace: {} with {}", substring, unsafe { replacement.display() }); - unsafe { CoTaskMemFree(replacement.0 as *mut _) }; + unsafe { CoTaskMemFree(replacement.as_ptr() as *mut _) }; } CORRECTIVE_ACTION_GET_SUGGESTIONS => { // Get an enumerator for all the suggestions for a substring @@ -58,13 +58,13 @@ fn main() -> Result<()> { unsafe { let _ = suggestions.Next(&mut suggestion, std::ptr::null_mut()); } - if suggestion[0].0.is_null() { + if suggestion[0].is_null() { break; } - println!("Maybe replace: {} with {}", substring, unsafe { suggestion[0].to_string().unwrap() }); + println!("Maybe replace: {} with {}", substring, unsafe { suggestion[0].display() }); - unsafe { CoTaskMemFree(suggestion[0].0 as *mut _) }; + unsafe { CoTaskMemFree(suggestion[0].as_ptr() as *mut _) }; } } _ => {}