From 6753b04af2af0a16a3dab59e4cedeecae4afc537 Mon Sep 17 00:00:00 2001 From: lisazeyen Date: Fri, 19 Apr 2024 09:55:47 +0200 Subject: [PATCH] update snakemake --- scripts/_helpers.py | 148 +++++++++++++++++++++++++++++++------------- 1 file changed, 105 insertions(+), 43 deletions(-) diff --git a/scripts/_helpers.py b/scripts/_helpers.py index 16ec914..95edb2b 100644 --- a/scripts/_helpers.py +++ b/scripts/_helpers.py @@ -55,16 +55,31 @@ def __dir__(self): return dict_keys + obj_attrs -def mock_snakemake(rulename, **wildcards): +def mock_snakemake( + rulename, + root_dir=None, + configfiles=None, + submodule_dir="workflow/submodules/pypsa-eur", + **wildcards, +): """ This function is expected to be executed from the 'scripts'-directory of ' the snakemake project. It returns a snakemake.script.Snakemake object, based on the Snakefile. + If a rule has wildcards, you have to specify them in **wildcards. + Parameters ---------- rulename: str name of the rule for which the snakemake object should be generated + root_dir: str/path-like + path to the root directory of the snakemake project + configfiles: list, str + list of configfiles to be used to update the config + submodule_dir: str, Path + in case PyPSA-Eur is used as a submodule, submodule_dir is + the path of pypsa-eur relative to the project directory. **wildcards: keyword arguments fixing the wildcards. Only necessary if wildcards are needed. 
@@ -72,48 +87,95 @@ def mock_snakemake(rulename, **wildcards): import os import snakemake as sm - from packaging.version import Version, parse + from pypsa.descriptors import Dict + from snakemake.api import Workflow + from snakemake.common import SNAKEFILE_CHOICES from snakemake.script import Snakemake - - script_dir = Path(__file__).parent.resolve() - assert ( - Path.cwd().resolve() == script_dir - ), f"mock_snakemake has to be run from the repository scripts directory {script_dir}" - os.chdir(script_dir.parent) - for p in sm.SNAKEFILE_CHOICES: - if os.path.exists(p): - snakefile = p - break - kwargs = dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {} - workflow = sm.Workflow(snakefile, overwrite_configfiles=[], **kwargs) - workflow.include(snakefile) - workflow.global_resources = {} - rule = workflow.get_rule(rulename) - dag = sm.dag.DAG(workflow, rules=[rule]) - wc = Dict(wildcards) - job = sm.jobs.Job(rule, dag, wc) - - def make_accessable(*ios): - for io in ios: - for i in range(len(io)): - io[i] = os.path.abspath(io[i]) - - make_accessable(job.input, job.output, job.log) - snakemake = Snakemake( - job.input, - job.output, - job.params, - job.wildcards, - job.threads, - job.resources, - job.log, - job.dag.workflow.config, - job.rule.name, - None, + from snakemake.settings import ( + ConfigSettings, + DAGSettings, + ResourceSettings, + StorageSettings, + WorkflowSettings, ) - # create log and output dir if not existent - for path in list(snakemake.log) + list(snakemake.output): - Path(path).parent.mkdir(parents=True, exist_ok=True) - os.chdir(script_dir) - return snakemake + script_dir = Path(__file__).parent.resolve() + if root_dir is None: + root_dir = script_dir.parent + else: + root_dir = Path(root_dir).resolve() + + user_in_script_dir = Path.cwd().resolve() == script_dir + if str(submodule_dir) in __file__: + # the submodule_dir path is only needed to locate the project dir + os.chdir(Path(__file__[: 
__file__.find(str(submodule_dir))])) + elif user_in_script_dir: + os.chdir(root_dir) + elif Path.cwd().resolve() != root_dir: + raise RuntimeError( + "mock_snakemake has to be run from the repository root" + f" {root_dir} or scripts directory {script_dir}" + ) + try: + for p in SNAKEFILE_CHOICES: + if os.path.exists(p): + snakefile = p + break + if configfiles is None: + configfiles = [] + elif isinstance(configfiles, str): + configfiles = [configfiles] + + resource_settings = ResourceSettings() + config_settings = ConfigSettings(configfiles=map(Path, configfiles)) + workflow_settings = WorkflowSettings() + storage_settings = StorageSettings() + dag_settings = DAGSettings(rerun_triggers=[]) + workflow = Workflow( + config_settings, + resource_settings, + workflow_settings, + storage_settings, + dag_settings, + storage_provider_settings=dict(), + ) + workflow.include(snakefile) + + if configfiles: + for f in configfiles: + if not os.path.exists(f): + raise FileNotFoundError(f"Config file {f} does not exist.") + workflow.configfile(f) + + workflow.global_resources = {} + rule = workflow.get_rule(rulename) + dag = sm.dag.DAG(workflow, rules=[rule]) + wc = Dict(wildcards) + job = sm.jobs.Job(rule, dag, wc) + + def make_accessable(*ios): + for io in ios: + for i, _ in enumerate(io): + io[i] = os.path.abspath(io[i]) + + make_accessable(job.input, job.output, job.log) + snakemake = Snakemake( + job.input, + job.output, + job.params, + job.wildcards, + job.threads, + job.resources, + job.log, + job.dag.workflow.config, + job.rule.name, + None, + ) + # create log and output dir if not existent + for path in list(snakemake.log) + list(snakemake.output): + Path(path).parent.mkdir(parents=True, exist_ok=True) + + finally: + if user_in_script_dir: + os.chdir(script_dir) + return snakemake \ No newline at end of file