Skip to content

Commit

Permalink
core: better support for ad-hoc configs
Browse files Browse the repository at this point in the history
properly reload/unload the relevant modules so hopefully no more weird hacks should be required

relevant
- karlicoss/promnesia#340
- #46
  • Loading branch information
karlicoss committed Feb 9, 2023
1 parent fb0c128 commit b30e681
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 5 deletions.
73 changes: 68 additions & 5 deletions my/core/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,75 @@ def override_config(config: F) -> Iterator[F]:
delattr(config, k)


# helper for tests? not sure if could be useful elsewhere
import importlib
import sys
from types import ModuleType
from typing import Optional, Union, Set
ModuleRegex = str
ModuleIsh = Union[ModuleRegex, ModuleType]
# TODO get rid of ModuleType... could add later?
@contextmanager
def tmp_config():
import my.config as C
with override_config(C):
yield C # todo not sure?
def _reload_modules(module: ModuleRegex) -> Iterator[None]:
def loaded_modules() -> Set[str]:
return {name for name in sys.modules if re.fullmatch(module, name)}

modules_before = loaded_modules()

for m in modules_before:
print("RELOADING!!!", module, m)
importlib.reload(sys.modules[m])

try:
yield
finally:
modules_after = loaded_modules()
for m in modules_after:
if m in modules_before:
# was previously loaded, so need to reload to pick up old config
print("UNLOADING!!", module, m)
importlib.reload(sys.modules[m])
else:
# wasn't previously loaded, so need to unload it
# otherwise it might fail due to missing config etc
print("DELETING!", module, m)
sys.modules.pop(m, None)


from contextlib import ExitStack
import re
from typing import Sequence
@contextmanager
def tmp_config(*, module: Optional[ModuleIsh]=None, config=None):
if module is None:
assert config is None
if module is not None:
assert config is not None

import my.config
with ExitStack() as module_reload_stack, override_config(my.config) as new_config:
if config is not None:
overrides = {k: v for k, v in vars(config).items() if not k.startswith('__')}
for k, v in overrides.items():
setattr(new_config, k, v)

if module is not None:
# modules: Sequence[Union[ModuleName, ModuleType]] = []
# if isinstance(module, str):
# modules = [name for name in sorted(sys.modules) if re.fullmatch(module, name)]
# else:
# modules = [module]

# print("MODULES TO UNLOAD!!!", modules)
# TODO always pass actual module object??
# for mod in modules:
# module_reload_stack.enter_context(_reload_modules(mod))
# TODO hmm why this works??
module_reload_stack.enter_context(_reload_modules(module)) # TODO rename to modules?
# module_reload_stack.enter_context(_reload_modules('my.browser'))
# module_reload_stack.enter_context(_reload_modules('my.browser.all'))
# module_reload_stack.enter_context(_reload_modules('my.browser.active_browser'))
# module_reload_stack.enter_context(_reload_modules('my.browser.export'))
yield new_config


def test_tmp_config() -> None:
Expand Down
21 changes: 21 additions & 0 deletions my/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'''
Just a demo module for testing and documentation purposes
'''
from dataclasses import dataclass
from typing import Iterator

from my.core import make_config

from my.config import simple as user_config


@dataclass
class simple(user_config):
count: int


config = make_config(simple)


def items() -> Iterator[int]:
yield from range(config.count)
30 changes: 30 additions & 0 deletions tests/test_tmp_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
import tempfile

from my.core.cfg import tmp_config

import pytest


def _init_default_config():
import my.config
class default_config:
count = 5
my.config.simple = default_config # type: ignore[attr-defined]
_init_default_config()


from my.simple import items


def test_tmp_config() -> None:
assert len(list(items())) == 5

class config:
class simple:
count = 3

with tmp_config(modules='my.simple', config=config):
assert len(list(items())) == 3

assert len(list(items())) == 5

0 comments on commit b30e681

Please sign in to comment.