Skip to content

Commit

Permalink
Generalize patching to support vendoring (#39)
Browse files Browse the repository at this point in the history
This PR makes a number of significant changes to the patching infrastructure:
1. This PR reorganizes the patching logic to be based on a more principled approach. Rather than maintaining lists of patch functions that are each responsible for filtering modules to apply themselves to, patches are organized in the patches directory in a tree structure matching dask itself. Patches are found and run by importing the same relative paths within the `patches` directory corresponding to a particular dask or distributed module.
2. It adds proper support for patching submodules. Previously the loader was being disabled whenever a real dask module was being imported, but this is problematic because if some dask modules import others they will pre-populate `sys.modules` with the real modules and therefore the loader will never be used for loading a patched version of the submodule.
3. Patches are no longer just functions applied to modules, they are arbitrary functions executed when a module is imported. As a result, a wider range of modifications is possible than was previously allowed. In particular:
4. The more general functions allow the entire module import to be hijacked and redirected to other modules.
5. The new framework is used to vendor a patched version of the accessor.py module in dask, which resolves the issues observed in dask/dask#11035.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

URL: #39
  • Loading branch information
vyasr authored Apr 4, 2024
1 parent b323830 commit a529757
Show file tree
Hide file tree
Showing 10 changed files with 597 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

# Always exclude vendored files from linting
exclude: ".*__rdd_patch_.*"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
73 changes: 34 additions & 39 deletions rapids_dask_dependency/dask_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,60 @@
import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import sys
import warnings
from contextlib import contextmanager

from .patches.dask import patches as dask_patches
from .patches.distributed import patches as distributed_patches

original_warn = warnings.warn


def _warning_with_increased_stacklevel(
message, category=None, stacklevel=1, source=None, **kwargs
):
# Patch warnings to have the right stacklevel
# Add 3 to the stacklevel to account for the 3 extra frames added by the loader: one
# in this warnings function, one in the actual loader, and one in the importlib
# call (not including all internal frames).
original_warn(message, category, stacklevel + 3, source, **kwargs)


@contextmanager
def patch_warning_stacklevel():
warnings.warn = _warning_with_increased_stacklevel
yield
warnings.warn = original_warn
from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec


class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
def __init__(self):
self._blocklist = set()

def create_module(self, spec):
if spec.name.startswith("dask") or spec.name.startswith("distributed"):
with self.disable(), patch_warning_stacklevel():
mod = importlib.import_module(spec.name)
with self.disable(spec.name):
try:
# Absolute import is important here to avoid shadowing the real dask
# and distributed modules in sys.modules. Bad things will happen if
# we use relative imports here.
proxy = importlib.import_module(
f"rapids_dask_dependency.patches.{spec.name}"
)
if hasattr(proxy, "load_module"):
return proxy.load_module(spec)
except ModuleNotFoundError:
pass

# Note: The spec does not make it clear whether we're guaranteed that spec
# is not a copy of the original spec, but that is the case for now. We need
# to assign this because the spec is used to update module attributes after
# it is initialized by create_module.
spec.origin = mod.__spec__.origin
spec.submodule_search_locations = mod.__spec__.submodule_search_locations
# Three extra stack frames: 1) DaskLoader.create_module,
# 2) importlib.import_module, and 3) the patched warnings function (not
# including the internal frames, which warnings ignores).
with patch_warning_stacklevel(3):
mod = importlib.import_module(spec.name)

# TODO: I assume we'll want to only apply patches to specific submodules,
# that'll be up to RAPIDS dask devs to decide.
patches = dask_patches if "dask" in spec.name else distributed_patches
for patch in patches:
patch(mod)
return mod
update_spec(spec, mod.__spec__)
return mod

def exec_module(self, _):
pass

@contextmanager
def disable(self):
sys.meta_path.remove(self)
def disable(self, name):
# This is a context manager that prevents this finder from intercepting calls to
# import a specific name. We must do this to avoid infinite recursion when
# calling import_module in create_module. However, we cannot blanket disable the
# finder because that causes it to be bypassed when transitive imports occur
# within import_module.
try:
self._blocklist.add(name)
yield
finally:
sys.meta_path.insert(0, self)
self._blocklist.remove(name)

def find_spec(self, fullname: str, _, __=None):
if fullname in self._blocklist:
return None
if (
fullname in ("dask", "distributed")
or fullname.startswith("dask.")
Expand Down
55 changes: 55 additions & 0 deletions rapids_dask_dependency/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.util
from abc import abstractmethod

from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec


class BaseImporter:
@abstractmethod
def load_module(self, spec):
pass


class MonkeyPatchImporter(BaseImporter):
"""The base importer for modules that are monkey-patched."""

def __init__(self, name, patch_func):
self.name = name.replace("rapids_dask_dependency.patches.", "")
self.patch_func = patch_func

def load_module(self, spec):
# Four extra stack frames: 1) DaskLoader.create_module, 2)
# MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the
# patched warnings function (not including the internal frames, which warnings
# ignores).
with patch_warning_stacklevel(4):
mod = importlib.import_module(self.name)
self.patch_func(mod)
update_spec(spec, mod.__spec__)
mod._rapids_patched = True
return mod


class VendoredImporter(BaseImporter):
"""The base importer for vendored modules."""

# Vendored files use a standard prefix to avoid name collisions.
default_prefix = "__rdd_patch_"

def __init__(self, module):
self.real_module_name = module.replace("rapids_dask_dependency.patches.", "")
module_parts = module.split(".")
module_parts[-1] = self.default_prefix + module_parts[-1]
self.vendored_module_name = ".".join(module_parts)

def load_module(self, spec):
vendored_module = importlib.import_module(self.vendored_module_name)
# At this stage the module loader must have been disabled for this module, so we
# can access the original module. We don't want to actually import it, we just
# want enough information on it to update the spec.
original_spec = importlib.util.find_spec(self.real_module_name)
update_spec(spec, original_spec)
return vendored_module
5 changes: 3 additions & 2 deletions rapids_dask_dependency/patches/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .add_patch_attr import add_patch_attr
from rapids_dask_dependency.importer import MonkeyPatchImporter

patches = [add_patch_attr]
_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
5 changes: 0 additions & 5 deletions rapids_dask_dependency/patches/dask/add_patch_attr.py

This file was deleted.

Loading

0 comments on commit a529757

Please sign in to comment.