Skip to content

Commit

Permalink
get step3 yaml from step2 result
Browse files Browse the repository at this point in the history
  • Loading branch information
xingzhongyu committed Feb 16, 2024
1 parent 4dd7948 commit 3bf9067
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions dance/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,50 @@ def generate_subsets(path, tune_mode, save_directory, file_path, log_dir, requir
config_dir = os.path.relpath(os.path.dirname(save_path), os.path.dirname(os.path.join(root_path, file_path)))
command_str = command_str + f"python {file_path} --config_dir={config_dir}/subset_{index}_ --count={count} > {log_dir}/{index}.log 2>&1 &\n"
return command_str, configs


def get_step3_yaml(conf_save_path="examples/tuning/cta_svm/config_yamls/params/",
                   conf_load_path="examples/tuning/cta_svm/cell_type_annotation_default_params.yaml",
                   result_load_path="examples/tuning/cta_svm/results/pipeline/best_test_acc.csv",
                   required_funs=["SetConfig"], required_indexes=[-1], root_path=None):
    """Generate the configuration files of step 3 based on the results of step 2.

    One YAML file is written per row of the step 2 result table. Each file is a
    copy of the configuration loaded from ``conf_load_path`` whose ``pipeline``
    entry is restricted to the steps named in that row (plus the required
    functions), in row order.

    Parameters
    ----------
    conf_save_path
        Storage directory of the configuration files generated in step 3.
        Created if it does not exist.
    conf_load_path
        Parameter search range of all preprocessing functions under a specific algorithm task.
    result_load_path
        The storage path of the result of step 2 (a CSV file whose columns starting
        with ``pipeline`` hold the selected function names).
    required_funs
        Required functions in step 3.
    required_indexes
        Location of required functions in step 3. A negative index denotes the
        position counted from the end of the resulting pipeline, so ``-1`` places
        the function last.
    root_path
        Root path of all paths. Defaults to the parent of the directory containing this file.

    Raises
    ------
    ValueError
        If ``required_funs`` and ``required_indexes`` have different lengths.

    """
    # NOTE: the list defaults are kept for interface compatibility; they are only read, never mutated.
    if len(required_funs) != len(required_indexes):
        raise ValueError(f"required_funs and required_indexes must have the same length, "
                         f"got {len(required_funs)} and {len(required_indexes)}")

    if root_path is None:
        root_path = Path(__file__).resolve().parent.parent
    conf_save_path = os.path.join(root_path, conf_save_path)
    conf_load_path = os.path.join(root_path, conf_load_path)
    result_load_path = os.path.join(root_path, result_load_path)

    conf = OmegaConf.load(conf_load_path)
    result = pd.read_csv(result_load_path)
    # Sort so that the pipeline columns are visited in a deterministic (lexicographic) order.
    pipeline_columns = sorted(col for col in result.columns if col.startswith("pipeline"))

    # OmegaConf.save does not create missing parent directories.
    os.makedirs(conf_save_path, exist_ok=True)

    for count, row in enumerate(result.loc[:, pipeline_columns].values, start=1):
        fun_names = _insert_required_funs(row, required_funs, required_indexes)
        # Keep every configured step whose target matches, preserving the order given by fun_names.
        pipeline = [step for name in fun_names for step in conf.pipeline if step["target"] == name]
        temp_conf = conf.copy()
        temp_conf.pipeline = pipeline
        OmegaConf.save(temp_conf, os.path.join(conf_save_path, f"{count}_test_acc_params_tuning_config.yaml"))


def _insert_required_funs(row, required_funs, required_indexes):
    """Return a list copy of ``row`` with each required function inserted at its index."""
    fun_names = list(row)
    for index, fun in zip(required_indexes, required_funs):
        if index < 0:
            # list.insert(-1, x) would put x *before* the last element; shift negative
            # indexes by one so that -1 really means "last position in the result".
            index = max(len(fun_names) + index + 1, 0)
        fun_names.insert(index, fun)
    return fun_names

0 comments on commit 3bf9067

Please sign in to comment.