Skip to content

Commit

Permalink
Added profiling support for malware scans
Browse files Browse the repository at this point in the history
  • Loading branch information
akenion committed Mar 27, 2024
1 parent 173e16b commit db62725
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 34 deletions.
6 changes: 6 additions & 0 deletions wordfence/cli/malwarescan/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@
"context": "CLI",
"argument_type": "FLAG",
"default": False
},
"profile": {
"description": "Profile scan performance",
"context": "CLI",
"argument_type": "FLAG",
"default": False
}
}

Expand Down
1 change: 1 addition & 0 deletions wordfence/cli/malwarescan/malwarescan.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def invoke(self) -> int:
debug=self.config.debug,
logging_initializer=self.context.get_log_settings().apply,
match_engine=match_engine,
profile=self.config.profile
)
if io_manager.should_read_stdin():
options.path_source = io_manager.get_input_reader()
Expand Down
2 changes: 1 addition & 1 deletion wordfence/scanning/matching/pcre.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _compile_regexes(self) -> None:
self.signatures_without_common_strings = \
self._extract_signatures_without_common_strings()

def prepare(self) -> None:
def _prepare(self) -> None:
self._compile_regexes()

def create_context(self) -> PcreMatcherContext:
Expand Down
158 changes: 125 additions & 33 deletions wordfence/scanning/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from enum import IntEnum
from multiprocessing import Queue, Process, Value, get_start_method
from dataclasses import dataclass
from typing import Set, Optional, Callable, Dict, NamedTuple, Tuple, List
from typing import Set, Optional, Callable, Dict, NamedTuple, Tuple, List, \
Union
from logging import Handler

from .exceptions import ScanningException, ScanningIoException
Expand All @@ -18,6 +19,7 @@
get_all_parents, PathSet
from ..util.units import scale_byte_unit
from ..logging import log, remove_initial_handler, VERBOSE
from ..util.profiling import Profiler, ProfileEvent, EventTimer

MAX_PENDING_FILES = 1000 # Arbitrary limit
MAX_PENDING_RESULTS = 100
Expand Down Expand Up @@ -66,6 +68,54 @@ class ScanEventType(IntEnum):
FATAL_EXCEPTION = 4
PROGRESS_UPDATE = 5
LOG_MESSAGE = 6
PROFILE_EVENT = 7


class ScanEvent:

# TODO: Define custom (more compact) pickle serialization format for this
# class as a potential performance improvement

def __init__(
self,
type: int,
data=None,
worker_index: Optional[int] = None
):
self.type = type
self.data = data
self.worker_index = worker_index


class ScanProfileEvent(ScanEvent):

def __init__(
self,
event: ProfileEvent,
worker_index: Optional[int] = None
):
super().__init__(ScanEventType.PROFILE_EVENT, event, worker_index)


def _put_profile_event(
queue: Queue,
event: Optional[Union[ProfileEvent, EventTimer]]
) -> None:
if event is None:
return
if isinstance(event, EventTimer):
event = event.stop()
queue.put(ScanProfileEvent(event))


def _event_timer(
condition: bool,
name: str,
is_global: bool = False
) -> Optional[EventTimer]:
if not condition:
return None
return EventTimer(name, is_global=is_global)


class EventQueueLogHandler(Handler):
Expand Down Expand Up @@ -111,6 +161,7 @@ class Options:
debug: bool = False
logging_initializer: Callable[[], None] = None
match_engine: MatchEngine = MatchEngine.get_default()
profile: bool = False


class Status(IntEnum):
Expand Down Expand Up @@ -221,7 +272,8 @@ def __init__(
use_log_events: bool = False,
event_queue: Optional[Queue] = None,
allow_io_errors: bool = False,
logging_initializer: Callable[[], None] = None
logging_initializer: Callable[[], None] = None,
profile: bool = False
):
self._input_queue = Queue(input_queue_size)
self.output_queue = Queue(output_queue_size)
Expand All @@ -234,6 +286,7 @@ def __init__(
self._event_queue = event_queue
self.allow_io_errors = allow_io_errors
self._logging_initializer = logging_initializer
self.profile = profile
self._path_count = 0
self._skipped_count = Value('i', 0)
super().__init__(name='file-locator')
Expand All @@ -257,7 +310,16 @@ def finalize_paths(self):
def get_next_file(self):
return self.output_queue.get()

def _put_profile_event(
self,
event: Optional[Union[ProfileEvent, EventTimer]]
) -> None:
if not self.profile:
return
_put_profile_event(self._event_queue, event)

def run(self):
timer = _event_timer(self.profile, 'file_locator_all', is_global=True)
if self._logging_initializer is not None:
self._logging_initializer()
if self._use_log_events:
Expand All @@ -269,6 +331,7 @@ def run(self):
skipped_count = 0
scanned_paths = PathSet()
while (path := self._input_queue.get()) is not None:
path_timer = _event_timer(self.profile, 'file_locator_path')
locator = FileLocator(
path=path,
file_filter=self.file_filter,
Expand All @@ -278,31 +341,17 @@ def run(self):
)
locator.locate()
skipped_count += locator.skipped_count
self._put_profile_event(path_timer)
except ScanningException as exception:
self.output_queue.put(ExceptionContainer(exception))
self._skipped_count.value = skipped_count
self.output_queue.put(None)
self._put_profile_event(timer)

def get_skipped_count(self) -> int:
return self._skipped_count.value


class ScanEvent:

# TODO: Define custom (more compact) pickle serialization format for this
# class as a potential performance improvement

def __init__(
self,
type: int,
data=None,
worker_index: Optional[int] = None
):
self.type = type
self.data = data
self.worker_index = worker_index


class ScanProgressMonitor(Process):

def __init__(self, status: Value, event_queue: Queue):
Expand Down Expand Up @@ -334,7 +383,8 @@ def __init__(
scanned_content_limit: Optional[int] = None,
use_log_events: bool = False,
allow_io_errors: bool = False,
logging_initializer: Callable[[], None] = None
logging_initializer: Callable[[], None] = None,
profile: bool = False
):
self.index = index
self._status = status
Expand All @@ -347,13 +397,20 @@ def __init__(
self._use_log_events = use_log_events
self._allow_io_errors = allow_io_errors
self._logging_initializer = logging_initializer
self._profile = profile
self.complete = Value(c_bool, False)
self._timer = None
super().__init__(name=self._generate_name())

def _generate_name(self) -> str:
return 'worker-' + str(self.index)

def work(self):
self._timer = _event_timer(
self._profile,
'scan_worker',
is_global=True
)
try:
self._working = True
self._matcher.prepare(thread=True)
Expand All @@ -374,12 +431,14 @@ def work(self):
{'exception': item}
)
else:
timer = _event_timer(self._profile, 'process_file')
try:
self._process_file(item, workspace)
except OSError as error:
self._put_io_error(ExceptionContainer(error))
except Exception as error:
self._put_error(ExceptionContainer(error))
self._put_profile_event(timer)
except queue.Empty:
if self._status.value == Status.PROCESSING_FILES:
self._complete()
Expand All @@ -404,9 +463,17 @@ def _put_error(self, error, fatal: bool = True) -> None:
def _put_io_error(self, error) -> None:
self._put_error(error, not self._allow_io_errors)

def _put_profile_event(
self,
event: Optional[Union[ProfileEvent, EventTimer]]
) -> None:
_put_profile_event(self._event_queue, event)

def _complete(self):
self._working = False
self.complete.value = True
if self._timer is not None:
self._put_profile_event(self._timer.stop())
self._put_event(ScanEventType.COMPLETED)

def is_complete(self) -> bool:
Expand All @@ -422,21 +489,29 @@ def _get_next_chunk_size(self, length: int) -> int:

def _process_file(self, path: str, workspace: Optional[MatchWorkspace]):
log.log(VERBOSE, f'Processing file: {path}')
open_timer = _event_timer(self._profile, 'open_file')
with open(path, mode='rb') as file, \
self._matcher.create_context() as context:
self._put_profile_event(open_timer)
length = 0
while (chunk_size := self._get_next_chunk_size(length)):
chunk_timer = _event_timer(self._profile, 'read_chunk')
chunk = file.read(chunk_size)
self._put_profile_event(chunk_timer)
if not chunk:
break
first = length == 0
length += len(chunk)
if context.process_chunk(
chunk,
start=first,
workspace=workspace
):
break
match_timer = _event_timer(self._profile, 'match_chunk')
try:
if context.process_chunk(
chunk,
start=first,
workspace=workspace
):
break
finally:
self._put_profile_event(match_timer)
self._put_event(
ScanEventType.FILE_PROCESSED,
{
Expand Down Expand Up @@ -532,7 +607,9 @@ def __init__(self, elapsed_time: int, metrics: ScanMetrics):

ScanResultCallback = Callable[[ScanResult], None]
ProgressReceiverCallback = Callable[[ScanProgressUpdate], None]
ScanFinishedCallback = Callable[[ScanMetrics, timing.Timer], None]
ScanFinishedCallback = Callable[
[ScanMetrics, timing.Timer, Optional[Profiler]], None
]


class ScanFinishedMessages(NamedTuple):
Expand Down Expand Up @@ -575,7 +652,8 @@ def get_scan_finished_messages(

def default_scan_finished_handler(
metrics: ScanMetrics,
timer: timing.Timer
timer: timing.Timer,
profiler: Optional[Profiler]
) -> None:
"""Used as the default ScanFinishedCallback"""
messages = get_scan_finished_messages(metrics, timer)
Expand All @@ -584,6 +662,8 @@ def default_scan_finished_handler(
if messages.skipped:
log.warning(messages.skipped)
log.info(messages.results)
if profiler is not None:
profiler.output_results()
return messages


Expand All @@ -603,7 +683,8 @@ def __init__(
use_log_events: bool = False,
allow_io_errors: bool = False,
debug: bool = False,
logging_initializer: Callable[[], None] = False
logging_initializer: Callable[[], None] = False,
profiler: Optional[Profiler] = None
):
self.size = size
self._matcher = matcher
Expand All @@ -619,6 +700,7 @@ def __init__(
self._allow_io_errors = allow_io_errors
self._debug = debug
self._logging_initializer = logging_initializer
self._profiler = profiler
self._completed = False

def __enter__(self):
Expand Down Expand Up @@ -678,7 +760,8 @@ def start(self):
self._scanned_content_limit,
self._use_log_events,
self._allow_io_errors,
self._logging_initializer
self._logging_initializer,
self._profiler is not None
)
worker.start()
self._workers.append(worker)
Expand Down Expand Up @@ -775,6 +858,9 @@ def await_results(
elif event.type == ScanEventType.LOG_MESSAGE:
message: str = event.data['message']
log.log(event.data['level'], message)
elif event.type == ScanEventType.PROFILE_EVENT:
if self._profiler is not None:
self._profiler.add_event(event.data)
return False

def is_failed(self) -> bool:
Expand Down Expand Up @@ -808,13 +894,17 @@ def scan(
) -> ScanMetrics:
"""Run a scan"""
timer = timing.Timer()
profiler = Profiler() if self.options.profile else None
event_queue = Queue(MAX_PENDING_RESULTS)
file_locator_process = FileLocatorProcess(
file_filter=self.options.file_filter,
use_log_events=use_log_events,
event_queue=event_queue if use_log_events else None,
event_queue=event_queue if (
use_log_events or profiler is not None
) else None,
allow_io_errors=self.options.allow_io_errors,
logging_initializer=self.options.logging_initializer
logging_initializer=self.options.logging_initializer,
profile=profiler is not None
)
file_locator_process.start()
self.active.append(file_locator_process)
Expand All @@ -839,7 +929,8 @@ def scan(
use_log_events=use_log_events,
allow_io_errors=self.options.allow_io_errors,
debug=self.options.debug,
logging_initializer=self.options.logging_initializer
logging_initializer=self.options.logging_initializer,
profiler=profiler
) as worker_pool:
def add_path(path: str):
while not file_locator_process.add_path(path):
Expand All @@ -861,7 +952,8 @@ def add_path(path: str):
scan_finished_handler = scan_finished_handler if scan_finished_handler\
else default_scan_finished_handler
metrics.skipped_files = file_locator_process.get_skipped_count()
scan_finished_handler(metrics, timer)
profiler.complete()
scan_finished_handler(metrics, timer, profiler)
return (metrics, timer)

def terminate(self) -> None:
Expand Down
Loading

0 comments on commit db62725

Please sign in to comment.