Skip to content

Commit

Permalink
Merge pull request #67 from felstead/trace-pid
Browse files Browse the repository at this point in the history
Add the ability to trace processes by process ID
  • Loading branch information
nico-abram committed May 2, 2024
2 parents 6b3bac8 + cb98272 commit daacd76
Showing 1 changed file with 51 additions and 41 deletions.
92 changes: 51 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
//! You can use [`trace_command`] to execute and sample an [`std::process::Command`].
//!
//! Or you can use [`trace_child`] to start tracing an [`std::process::Child`].
// You can also trace an arbitrary process using [`trace_pid`].
//! You can also trace an arbitrary process using [`trace_pid`].

#![allow(clippy::field_reassign_with_default)]

use object::Object;
use windows::core::{GUID, PCSTR, PSTR};
use windows::Win32::Foundation::{
CloseHandle, DuplicateHandle, GetLastError, DUPLICATE_SAME_ACCESS, ERROR_SUCCESS,
ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE, INVALID_HANDLE_VALUE, WIN32_ERROR,
CloseHandle, GetLastError, ERROR_SUCCESS, ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE,
INVALID_HANDLE_VALUE, WIN32_ERROR,
};
use windows::Win32::Security::{
AdjustTokenPrivileges, LookupPrivilegeValueW, SE_PRIVILEGE_ENABLED, TOKEN_ADJUST_PRIVILEGES,
Expand All @@ -32,16 +32,16 @@ use windows::Win32::System::Diagnostics::Etw::{
use windows::Win32::System::SystemInformation::{GetVersionExA, OSVERSIONINFOA};
use windows::Win32::System::SystemServices::SE_SYSTEM_PROFILE_NAME;
use windows::Win32::System::Threading::{
GetCurrentProcess, GetCurrentThread, OpenProcessToken, SetThreadPriority, CREATE_SUSPENDED,
THREAD_PRIORITY_TIME_CRITICAL,
GetCurrentProcess, GetCurrentThread, OpenProcess, OpenProcessToken, SetThreadPriority,
WaitForSingleObject, CREATE_SUSPENDED, PROCESS_ALL_ACCESS, THREAD_PRIORITY_TIME_CRITICAL,
};

use pdb_addr2line::{pdb::PDB, ContextPdbData};

use std::ffi::OsString;
use std::io::{Read, Write};
use std::mem::size_of;
use std::os::windows::{ffi::OsStringExt, prelude::AsRawHandle};
use std::os::windows::ffi::OsStringExt;
use std::path::PathBuf;
use std::ptr::{addr_of, addr_of_mut};
use std::sync::atomic::{AtomicBool, Ordering};
Expand Down Expand Up @@ -106,8 +106,10 @@ pub enum Error {
Write(std::io::Error),
/// Error spawning a suspended process
SpawnErr(std::io::Error),
/// Error waiting for child
WaitOnChildErr(std::io::Error),
/// Error waiting for child, abandoned
WaitOnChildErrAbandoned,
/// Error waiting for child, timed out
WaitOnChildErrTimeout,
/// A call to a windows API function returned an error and we didn't know how to handle it
Other(WIN32_ERROR, String, &'static str),
/// We require Windows 7 or greater
Expand Down Expand Up @@ -146,24 +148,25 @@ fn get_last_error(extra: &'static str) -> Error {
Error::Other(code, code_str.to_string(), extra)
}

/// `h` must be a valid handle
unsafe fn clone_handle(h: HANDLE) -> Result<HANDLE> {
let mut target_h = HANDLE::default();
let ret = DuplicateHandle(
GetCurrentProcess(),
h,
GetCurrentProcess(),
&mut target_h,
0,
false,
DUPLICATE_SAME_ACCESS,
);
if ret.0 == 0 {
return Err(get_last_error("clone_handle"));
/// A wrapper around `OpenProcess` that returns a handle with all access rights
unsafe fn handle_from_process_id(process_id: u32) -> Result<HANDLE> {
match OpenProcess(PROCESS_ALL_ACCESS, false, process_id) {
Ok(handle) => Ok(handle),
Err(_) => Err(get_last_error("handle_from_process_id")),
}
Ok(target_h)
}
fn acquire_priviledges() -> Result<()> {

unsafe fn wait_for_process_by_handle(handle: HANDLE) -> Result<()> {
let ret = WaitForSingleObject(handle, 0xFFFFFFFF);
match ret.0 {
0 => Ok(()),
0x00000080 => Err(Error::WaitOnChildErrAbandoned),
0x00000102 => Err(Error::WaitOnChildErrTimeout),
_ => Err(get_last_error("wait_for_process_by_handle")),
}
}

fn acquire_privileges() -> Result<()> {
let mut privs = TOKEN_PRIVILEGES::default();
privs.PrivilegeCount = 1;
privs.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
Expand Down Expand Up @@ -195,8 +198,8 @@ fn acquire_priviledges() -> Result<()> {
Ok(())
}
/// SAFETY: is_suspended must only be true if `target_process` is suspended
unsafe fn trace_from_process(
target_process: &mut std::process::Child,
unsafe fn trace_from_process_id(
target_process_id: u32,
is_suspended: bool,
kernel_stacks: bool,
) -> Result<TraceContext> {
Expand All @@ -213,7 +216,7 @@ unsafe fn trace_from_process(
{
return Err(Error::UnsupportedOsVersion);
}
acquire_priviledges()?;
acquire_privileges()?;

// Set the sampling interval
// Only for Win8 or more
Expand Down Expand Up @@ -263,6 +266,7 @@ unsafe fn trace_from_process(
const PROPS_SIZE: usize = size_of::<EVENT_TRACE_PROPERTIES>() + KERNEL_LOGGER_NAMEA_LEN + 1;
#[derive(Clone)]
#[repr(C)]
#[allow(non_camel_case_types)]
struct EVENT_TRACE_PROPERTIES_WITH_STRING {
data: EVENT_TRACE_PROPERTIES,
s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1],
Expand Down Expand Up @@ -341,10 +345,8 @@ unsafe fn trace_from_process(
}
}

let target_pid = target_process.id();
// std Child closes the handle when it drops so we clone it
let target_proc_handle = clone_handle(HANDLE(target_process.as_raw_handle() as isize))?;
let mut context = TraceContext::new(target_proc_handle, target_pid, kernel_stacks)?;
let target_proc_handle = handle_from_process_id(target_process_id)?;
let mut context = TraceContext::new(target_proc_handle, target_process_id, kernel_stacks)?;
//TODO: Do we need to Box the context?

let mut log = EVENT_TRACE_LOGFILEA::default();
Expand Down Expand Up @@ -438,6 +440,7 @@ unsafe fn trace_from_process(
#[repr(C)]
#[derive(Debug)]
#[allow(non_snake_case)]
#[allow(non_camel_case_types)]
struct EVENT_HEADERR {
Size: u16,
HeaderType: u16,
Expand All @@ -456,6 +459,7 @@ unsafe fn trace_from_process(
#[repr(C)]
#[derive(Debug)]
#[allow(non_snake_case)]
#[allow(non_camel_case_types)]
struct EVENT_RECORDD {
EventHeader: EVENT_HEADERR,
BufferContextAnonymousProcessorNumber: u8,
Expand Down Expand Up @@ -520,8 +524,9 @@ unsafe fn trace_from_process(
std::mem::transmute(NtResumeProcess);
NtResumeProcess(context.target_process_handle.0);
}

// Wait for it to end
target_process.wait().map_err(Error::WaitOnChildErr)?;
wait_for_process_by_handle(target_proc_handle)?;
// This unblocks ProcessTrace
let ret = ControlTraceA(
<CONTROLTRACE_HANDLE as Default>::default(),
Expand Down Expand Up @@ -552,14 +557,18 @@ unsafe fn trace_from_process(

/// The sampled results from a process execution
pub struct CollectionResults(TraceContext);
/// Trace an existing child process based only on its process ID (pid).
/// It is recommended that you use `trace_command` instead, since it suspends the process on creation
/// and only resumes it after the trace has started, ensuring that all samples are captured.
pub fn trace_pid(process_id: u32, kernel_stacks: bool) -> Result<CollectionResults> {
let res = unsafe { trace_from_process_id(process_id, false, kernel_stacks) };
res.map(CollectionResults)
}
/// Trace an existing child process.
/// It is recommended that you use `trace_command` instead, since it suspends the process on creation
/// and only resumes it after the trace has started, ensuring that all samples are captured.
pub fn trace_child(
mut process: std::process::Child,
kernel_stacks: bool,
) -> Result<CollectionResults> {
let res = unsafe { trace_from_process(&mut process, false, kernel_stacks) };
pub fn trace_child(process: std::process::Child, kernel_stacks: bool) -> Result<CollectionResults> {
let res = unsafe { trace_from_process_id(process.id(), false, kernel_stacks) };
res.map(CollectionResults)
}
/// Execute `command` and trace it, periodically collecting call stacks.
Expand All @@ -578,7 +587,7 @@ pub fn trace_command(
.creation_flags(CREATE_SUSPENDED.0)
.spawn()
.map_err(Error::SpawnErr)?;
let res = unsafe { trace_from_process(&mut proc, true, kernel_stacks) };
let res = unsafe { trace_from_process_id(proc.id(), true, kernel_stacks) };
if res.is_err() {
// Kill the suspended process if we had some kind of error
let _ = proc.kill();
Expand Down Expand Up @@ -718,7 +727,7 @@ impl<'a> CallStack<'a> {
/// Iterate addresses in this callstack
///
/// This also performs symbol resolution if possible, and tries to find the image (DLL/EXE) it comes from
fn iter_resolved_addresses2<
fn iter_resolved_addresses<
F: for<'b> FnMut(u64, u64, &'b [&'b str], Option<&'b str>) -> Result<()>,
>(
&'a self,
Expand All @@ -742,7 +751,7 @@ impl<'a> CallStack<'a> {
}
let mut symbol_names = symbol_names_storage;

let module = pdb_db.range(..addr).rev().next();
let module = pdb_db.range(..addr).next_back();
let module = match module {
None => {
f(addr, 0, &[], None)?;
Expand Down Expand Up @@ -790,7 +799,7 @@ impl CollectionResults {
let mut v = vec![];

for callstack in self.iter_callstacks() {
callstack.iter_resolved_addresses2(
callstack.iter_resolved_addresses(
&pdb_db,
&mut v,
|address, displacement, symbol_names, image_name| {
Expand Down Expand Up @@ -882,6 +891,7 @@ fn list_kernel_modules() -> Vec<(OsString, u64, u64)> {
#[repr(C)]
#[derive(Debug)]
#[allow(non_snake_case)]
#[allow(non_camel_case_types)]
struct _RTL_PROCESS_MODULE_INFORMATION {
Section: *mut std::ffi::c_void,
MappedBase: *mut std::ffi::c_void,
Expand Down

0 comments on commit daacd76

Please sign in to comment.