forked from rapidsai/rapids-dask-dependency
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize patching to support vendoring (rapidsai#39)
This PR makes a number of significant changes to the patching infrastructure: 1. This PR reorganizes the patching logic to be based on a more principled approach. Rather than maintaining lists of patch functions that are each responsible for filtering modules to apply themselves to, patches are organized in the patches directory in a tree structure matching dask itself. Patches are found and run by importing the same relative paths within the `patches` directory corresponding to a particular dask or distributed module. 2. It adds proper support for patching submodules. Previously the loader was being disabled whenever a real dask module was being imported, but this is problematic because if some dask modules import others they will pre-populate `sys.modules` with the real modules and therefore the loader will never be used for loading a patched version of the submodule. 3. Patches are no longer just functions applied to modules, they are arbitrary functions executed when a module is imported. As a result, a wider range of modifications is possible than was previously allowed. In particular: 4. The more general functions allow the entire module import to be hijacked and redirected to other modules. 5. The new framework is used to vendor a patched version of the accessor.py module in dask, which resolves the issues observed in dask/dask#11035. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Richard (Rick) Zamora (https://github.com/rjzamora) URL: rapidsai#39
- Loading branch information
Showing
10 changed files
with
597 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
import importlib | ||
import importlib.util | ||
from abc import abstractmethod | ||
|
||
from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec | ||
|
||
|
||
class BaseImporter: | ||
@abstractmethod | ||
def load_module(self, spec): | ||
pass | ||
|
||
|
||
class MonkeyPatchImporter(BaseImporter): | ||
"""The base importer for modules that are monkey-patched.""" | ||
|
||
def __init__(self, name, patch_func): | ||
self.name = name.replace("rapids_dask_dependency.patches.", "") | ||
self.patch_func = patch_func | ||
|
||
def load_module(self, spec): | ||
# Four extra stack frames: 1) DaskLoader.create_module, 2) | ||
# MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the | ||
# patched warnings function (not including the internal frames, which warnings | ||
# ignores). | ||
with patch_warning_stacklevel(4): | ||
mod = importlib.import_module(self.name) | ||
self.patch_func(mod) | ||
update_spec(spec, mod.__spec__) | ||
mod._rapids_patched = True | ||
return mod | ||
|
||
|
||
class VendoredImporter(BaseImporter): | ||
"""The base importer for vendored modules.""" | ||
|
||
# Vendored files use a standard prefix to avoid name collisions. | ||
default_prefix = "__rdd_patch_" | ||
|
||
def __init__(self, module): | ||
self.real_module_name = module.replace("rapids_dask_dependency.patches.", "") | ||
module_parts = module.split(".") | ||
module_parts[-1] = self.default_prefix + module_parts[-1] | ||
self.vendored_module_name = ".".join(module_parts) | ||
|
||
def load_module(self, spec): | ||
vendored_module = importlib.import_module(self.vendored_module_name) | ||
# At this stage the module loader must have been disabled for this module, so we | ||
# can access the original module. We don't want to actually import it, we just | ||
# want enough information on it to update the spec. | ||
original_spec = importlib.util.find_spec(self.real_module_name) | ||
update_spec(spec, original_spec) | ||
return vendored_module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from .add_patch_attr import add_patch_attr | ||
from rapids_dask_dependency.importer import MonkeyPatchImporter | ||
|
||
patches = [add_patch_attr] | ||
_importer = MonkeyPatchImporter(__name__, lambda _: None) | ||
load_module = _importer.load_module |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.