Skip to content

Commit

Permalink
Add display methods for string pointer types (#1904)
Browse files Browse the repository at this point in the history
  • Loading branch information
rylev authored Jul 13, 2022
1 parent d734f1d commit 8c093b4
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 18 deletions.
14 changes: 7 additions & 7 deletions crates/libs/windows/src/core/strings/literals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ macro_rules! w {
let mut buffer = [0; OUTPUT_LEN];
let mut input_pos = 0;
let mut output_pos = 0;
while let Some((mut code_point, new_pos)) = ::windows::core::decode_utf8(INPUT, input_pos) {
while let Some((mut code_point, new_pos)) = ::windows::core::decode_utf8_char(INPUT, input_pos) {
input_pos = new_pos;
if code_point <= 0xffff {
buffer[output_pos] = code_point as u16;
Expand Down Expand Up @@ -53,7 +53,7 @@ pub use s;
pub use w;

#[doc(hidden)]
pub const fn decode_utf8(bytes: &[u8], mut pos: usize) -> Option<(u32, usize)> {
pub const fn decode_utf8_char(bytes: &[u8], mut pos: usize) -> Option<(u32, usize)> {
if bytes.len() == pos {
return None;
}
Expand Down Expand Up @@ -130,7 +130,7 @@ pub struct HSTRING_HEADER {
pub const fn utf16_len(bytes: &[u8]) -> usize {
let mut pos = 0;
let mut len = 0;
while let Some((code_point, new_pos)) = decode_utf8(bytes, pos) {
while let Some((code_point, new_pos)) = decode_utf8_char(bytes, pos) {
pos = new_pos;
len += if code_point <= 0xffff { 1 } else { 2 };
}
Expand All @@ -143,10 +143,10 @@ mod tests {

#[test]
fn test() {
assert_eq!(decode_utf8(b"123", 0), Some((0x31, 1)));
assert_eq!(decode_utf8(b"123", 1), Some((0x32, 2)));
assert_eq!(decode_utf8(b"123", 2), Some((0x33, 3)));
assert_eq!(decode_utf8(b"123", 3), None);
assert_eq!(decode_utf8_char(b"123", 0), Some((0x31, 1)));
assert_eq!(decode_utf8_char(b"123", 1), Some((0x32, 2)));
assert_eq!(decode_utf8_char(b"123", 2), Some((0x33, 3)));
assert_eq!(decode_utf8_char(b"123", 3), None);
assert_eq!(utf16_len(b"123"), 3);
assert_eq!(utf16_len("α & ω".as_bytes()), 5);
}
Expand Down
57 changes: 57 additions & 0 deletions crates/libs/windows/src/core/strings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,60 @@ extern "C" {
pub fn strlen(s: PCSTR) -> usize;
pub fn wcslen(s: PCWSTR) -> usize;
}

/// An internal helper for decoding an iterator of chars and displaying them
struct Decode<F>(F);

impl<F, R, E> core::fmt::Display for Decode<F>
where
F: Clone + FnOnce() -> R,
R: IntoIterator<Item = core::result::Result<char, E>>,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use core::fmt::Write;
let iter = self.0.clone();
for c in iter().into_iter() {
f.write_char(c.unwrap_or_else(|_| std::char::REPLACEMENT_CHARACTER))?
}
Ok(())
}
}

/// Mirror of `std::char::decode_utf16` for utf-8.
fn decode_utf8<'a>(mut buffer: &'a [u8]) -> impl Iterator<Item = core::result::Result<char, std::str::Utf8Error>> + 'a {
let mut current = "".chars();
let mut previous_error = None;
std::iter::from_fn(move || {
loop {
match (current.next(), previous_error) {
(Some(c), _) => return Some(Ok(c)),
// Return the previous error
(None, Some(e)) => {
previous_error = None;
return Some(Err(e));
}
// We're completely done
(None, None) if buffer.is_empty() => return None,
(None, None) => {
match std::str::from_utf8(buffer) {
Ok(s) => {
current = s.chars();
buffer = &[];
}
Err(e) => {
let (valid, rest) = buffer.split_at(e.valid_up_to());
// Skip the invalid sequence and stop completely if we ended early
let invalid_sequence_length = e.error_len()?;
buffer = &rest[invalid_sequence_length..];

// Set the current iterator to the valid section and indicate previous error
// SAFETY: `valid` is known to be valid utf-8 from error
current = unsafe { std::str::from_utf8_unchecked(valid) }.chars();
previous_error = Some(e);
}
}
}
}
}
})
}
22 changes: 22 additions & 0 deletions crates/libs/windows/src/core/strings/pcstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ impl PCSTR {
pub unsafe fn to_string(&self) -> core::result::Result<String, std::string::FromUtf8Error> {
String::from_utf8(self.as_bytes().into())
}

/// Allow this string to be displayed.
///
/// # Safety
///
/// See the safety information for `PCSTR::as_bytes`.
pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a {
Decode(move || decode_utf8(self.as_bytes()))
}
}

unsafe impl Abi for PCSTR {
Expand All @@ -57,3 +66,16 @@ impl From<Option<PCSTR>> for PCSTR {
from.unwrap_or_else(Self::null)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn can_display() {
// 💖 followed by an invalid byte sequence and then an incomplete one
let s = vec![240, 159, 146, 150, 255, 240, 159, 0];
let s = PCSTR::from_raw(s.as_ptr());
assert_eq!("💖�", format!("{}", unsafe { s.display() }));
}
}
9 changes: 9 additions & 0 deletions crates/libs/windows/src/core/strings/pcwstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ impl PCWSTR {
pub unsafe fn to_string(&self) -> core::result::Result<String, std::string::FromUtf16Error> {
String::from_utf16(self.as_wide())
}

/// Allow this string to be displayed.
///
/// # Safety
///
/// See the safety information for `PCWSTR::as_wide`.
pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a {
Decode(move || core::char::decode_utf16(self.as_wide().iter().cloned()))
}
}

unsafe impl Abi for PCWSTR {
Expand Down
9 changes: 9 additions & 0 deletions crates/libs/windows/src/core/strings/pstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ impl PSTR {
pub unsafe fn to_string(&self) -> core::result::Result<String, std::string::FromUtf8Error> {
String::from_utf8(self.as_bytes().into())
}

/// Allow this string to be displayed.
///
/// # Safety
///
/// See the safety information for `PSTR::as_bytes`.
pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a {
Decode(move || decode_utf8(self.as_bytes()))
}
}

unsafe impl Abi for PSTR {
Expand Down
21 changes: 15 additions & 6 deletions crates/libs/windows/src/core/strings/pwstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ use super::*;
pub struct PWSTR(pub *mut u16);

impl PWSTR {
/// Construct a new `PWSTR` from a raw pointer
/// Construct a new `PWSTR` from a raw pointer.
pub const fn from_raw(ptr: *mut u16) -> Self {
Self(ptr)
}

/// Construct a null `PWSTR`
/// Construct a null `PWSTR`.
pub fn null() -> Self {
Self(core::ptr::null_mut())
}

/// Returns a raw pointer to the `PWSTR`
/// Returns a raw pointer to the `PWSTR`.
pub fn as_ptr(&self) -> *mut u16 {
self.0
}

/// Checks whether the `PWSTR` is null
/// Checks whether the `PWSTR` is null.
pub fn is_null(&self) -> bool {
self.0.is_null()
}

/// String data without the trailing 0
/// String data without the trailing 0.
///
/// # Safety
///
Expand All @@ -40,10 +40,19 @@ impl PWSTR {
///
/// # Safety
///
/// See the safety information for `PWSTR::as_bytes`.
/// See the safety information for `PWSTR::as_wide`.
pub unsafe fn to_string(&self) -> core::result::Result<String, std::string::FromUtf16Error> {
String::from_utf16(self.as_wide().into())
}

/// Allow this string to be displayed.
///
/// # Safety
///
/// See the safety information for `PWSTR::as_wide`.
pub unsafe fn display<'a>(&'a self) -> impl core::fmt::Display + 'a {
Decode(move || core::char::decode_utf16(self.as_wide().iter().cloned()))
}
}

unsafe impl Abi for PWSTR {
Expand Down
10 changes: 5 additions & 5 deletions crates/samples/spellchecker/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ fn main() -> Result<()> {
// Get the replacement as a widestring and convert to a Rust String
let replacement = unsafe { error.Replacement()? };

println!("Replace: {} with {}", substring, unsafe { replacement.to_string().unwrap() });
println!("Replace: {} with {}", substring, unsafe { replacement.display() });

unsafe { CoTaskMemFree(replacement.0 as *mut _) };
unsafe { CoTaskMemFree(replacement.as_ptr() as *mut _) };
}
CORRECTIVE_ACTION_GET_SUGGESTIONS => {
// Get an enumerator for all the suggestions for a substring
Expand All @@ -58,13 +58,13 @@ fn main() -> Result<()> {
unsafe {
let _ = suggestions.Next(&mut suggestion, std::ptr::null_mut());
}
if suggestion[0].0.is_null() {
if suggestion[0].is_null() {
break;
}

println!("Maybe replace: {} with {}", substring, unsafe { suggestion[0].to_string().unwrap() });
println!("Maybe replace: {} with {}", substring, unsafe { suggestion[0].display() });

unsafe { CoTaskMemFree(suggestion[0].0 as *mut _) };
unsafe { CoTaskMemFree(suggestion[0].as_ptr() as *mut _) };
}
}
_ => {}
Expand Down

0 comments on commit 8c093b4

Please sign in to comment.