Skip to content

Commit

Permalink
Refactored and reformatted all the files according to the pre-commit …
Browse files Browse the repository at this point in the history
…configuration.
  • Loading branch information
Sai-Suraj-27 committed Aug 15, 2023
1 parent f209cb8 commit 173e614
Show file tree
Hide file tree
Showing 269 changed files with 12,754 additions and 8,193 deletions.
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml

- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
args: # arguments to configure black
- --line-length=100
6 changes: 3 additions & 3 deletions README_ospp.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# OSPP
I changed the sedna source code to implement my algorithm.
Please turn to https://github.com/kubeedge/sedna/pull/378 and https://github.com/nailtu30/sedna/blob/ospp-final/README_ospp.md for more information.
# OSPP
I changed the sedna source code to implement my algorithm.
Please turn to https://github.com/kubeedge/sedna/pull/378 and https://github.com/nailtu30/sedna/blob/ospp-final/README_ospp.md for more information.
15 changes: 9 additions & 6 deletions core/cmd/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

"""main"""

import sys
import argparse
import sys

from core.common.log import LOGGER
from core.common import utils
from core.cmd.obj import BenchmarkingJob
from core.__version__ import __version__
from core.cmd.obj import BenchmarkingJob
from core.common import utils
from core.common.log import LOGGER


def main():
Expand All @@ -30,7 +30,9 @@ def main():
args = parser.parse_args()
config_file = args.benchmarking_config_file
if not utils.is_local_file(config_file):
raise SystemExit(f"not found benchmarking config({config_file}) file in local")
raise SystemExit(
f"not found benchmarking config({config_file}) file in local"
)

config = utils.yaml2dict(args.benchmarking_config_file)
job = BenchmarkingJob(config[str.lower(BenchmarkingJob.__name__)])
Expand All @@ -50,7 +52,8 @@ def _generate_parser():
"--benchmarking_config_file",
nargs="?",
type=str,
help="run a benchmarking job, " "and the benchmarking config file must be yaml/yml file.",
help="run a benchmarking job, "
"and the benchmarking config file must be yaml/yml file.",
)

parser.add_argument(
Expand Down
19 changes: 13 additions & 6 deletions core/cmd/obj/benchmarkingjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

from core.common import utils
from core.common.constant import TestObjectType
from core.testenvmanager.testenv import TestEnv
from core.storymanager.rank import Rank
from core.testcasecontroller.simulation import Simulation
from core.testcasecontroller.simulation_system_admin import build_simulation_enviroment
from core.testcasecontroller.testcasecontroller import TestCaseController
from core.testenvmanager.testenv import TestEnv


# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -52,15 +52,19 @@ def __init__(self, config):
def _check_fields(self):
if not self.name and not isinstance(self.name, str):
raise ValueError(
f"benchmarkingjob's name({self.name}) must be provided" f" and be string type."
f"benchmarkingjob's name({self.name}) must be provided"
f" and be string type."
)

if not isinstance(self.workspace, str):
raise ValueError(f"benchmarkingjob's workspace({self.workspace}) must be string type.")
raise ValueError(
f"benchmarkingjob's workspace({self.workspace}) must be string type."
)

if not self.test_object and not isinstance(self.test_object, dict):
raise ValueError(
f"benchmarkingjob's test_object({self.test_object})" f" must be dict type."
f"benchmarkingjob's test_object({self.test_object})"
f" must be dict type."
)

test_object_types = [e.value for e in TestObjectType.__members__.values()]
Expand All @@ -73,7 +77,8 @@ def _check_fields(self):

if not self.test_object.get(test_object_type):
raise ValueError(
f"benchmarkingjob' test_object doesn't find" f" the field({test_object_type})."
f"benchmarkingjob' test_object doesn't find"
f" the field({test_object_type})."
)

def run(self):
Expand All @@ -95,7 +100,9 @@ def run(self):
test_env=self.test_env, test_object=self.test_object
)

succeed_testcases, test_results = self.testcase_controller.run_testcases(self.workspace)
succeed_testcases, test_results = self.testcase_controller.run_testcases(
self.workspace
)

if test_results:
self.rank.save(succeed_testcases, test_results, output_dir=self.workspace)
Expand Down
1 change: 1 addition & 0 deletions core/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class SystemMetricType(Enum):
"""
System metric type of ianvs.
"""

# pylint: disable=C0103
SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio"
FWT = "FWT"
Expand Down
1 change: 1 addition & 0 deletions core/common/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Base logger"""

import logging

import colorlog


Expand Down
6 changes: 4 additions & 2 deletions core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import os
import sys
import time

from importlib import import_module
from inspect import getfullargspec

import yaml


Expand Down Expand Up @@ -63,7 +63,9 @@ def py2dict(url):
mod = import_module(module_name)
sys.path.pop(0)
raw_dict = {
name: value for name, value in mod.__dict__.items() if not name.startswith("__")
name: value
for name, value in mod.__dict__.items()
if not name.startswith("__")
}
sys.modules.pop(module_name)

Expand Down
34 changes: 25 additions & 9 deletions core/storymanager/rank/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pandas as pd

from core.common import utils
from core.storymanager.visualization import get_visualization_func, draw_heatmap_picture
from core.storymanager.visualization import draw_heatmap_picture, get_visualization_func


# pylint: disable=R0902
Expand Down Expand Up @@ -59,11 +59,14 @@ def _parse_config(self, config):

def _check_fields(self):
if not self.sort_by and not isinstance(self.sort_by, list):
raise ValueError(f"rank's sort_by({self.sort_by}) must be provided and be list type.")
raise ValueError(
f"rank's sort_by({self.sort_by}) must be provided and be list type."
)

if not self.visualization and not isinstance(self.visualization, dict):
raise ValueError(
f"rank's visualization({self.visualization}) " f"must be provided and be dict type."
f"rank's visualization({self.visualization}) "
f"must be provided and be dict type."
)

if not self.selected_dataitem and not isinstance(self.selected_dataitem, dict):
Expand All @@ -83,7 +86,8 @@ def _check_fields(self):

if not self.save_mode and not isinstance(self.save_mode, list):
raise ValueError(
f"rank's save_mode({self.save_mode}) " f"must be provided and be list type."
f"rank's save_mode({self.save_mode}) "
f"must be provided and be list type."
)

@classmethod
Expand Down Expand Up @@ -149,7 +153,9 @@ def _get_all(self, test_cases, test_results) -> pd.DataFrame:
all_df.loc[i][0] = algorithm.name
# fill metric columns of algorithm
for metric_name in test_results[test_case.id][0]:
all_df.loc[i][metric_name] = test_results[test_case.id][0].get(metric_name)
all_df.loc[i][metric_name] = test_results[test_case.id][0].get(
metric_name
)

# file paradigm column of algorithm
all_df.loc[i]["paradigm"] = algorithm.paradigm_type
Expand Down Expand Up @@ -192,7 +198,15 @@ def _get_selected(self, test_cases, test_results) -> pd.DataFrame:
if metric_names == ["all"]:
metric_names = self._get_all_metric_names(test_results)

header = ["algorithm", *metric_names, "paradigm", *module_types, *hps_names, "time", "url"]
header = [
"algorithm",
*metric_names,
"paradigm",
*module_types,
*hps_names,
"time",
"url",
]

all_df = copy.deepcopy(self.all_df)
selected_df = pd.DataFrame(all_df, columns=header)
Expand All @@ -207,15 +221,17 @@ def _save_selected(self, test_cases, test_results):
# pylint: disable=E1101
selected_df = self._get_selected(test_cases, test_results)
selected_df.index = pd.np.arange(1, len(selected_df) + 1)
selected_df.to_csv(self.selected_rank_file, index_label="rank", encoding="utf-8", sep=" ")
selected_df.to_csv(
self.selected_rank_file, index_label="rank", encoding="utf-8", sep=" "
)

def _draw_pictures(self, test_cases, test_results):
# pylint: disable=E1101
for test_case in test_cases:
out_put = test_case.output_dir
test_result = test_results[test_case.id][0]
matrix = test_result.get('Matrix')
#print(out_put)
matrix = test_result.get("Matrix")
# print(out_put)
for key in matrix.keys():
draw_heatmap_picture(out_put, key, matrix[key])

Expand Down
2 changes: 1 addition & 1 deletion core/storymanager/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.

# pylint: disable=missing-module-docstring
from .visualization import get_visualization_func, draw_heatmap_picture
from .visualization import draw_heatmap_picture, get_visualization_func
21 changes: 14 additions & 7 deletions core/storymanager/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

"""Visualization"""

import sys
import os
import sys

import matplotlib.pyplot as plt
from prettytable import from_csv

Expand All @@ -26,24 +27,30 @@ def print_table(rank_file):
table = from_csv(file)
print(table)


def draw_heatmap_picture(output, title, matrix):
"""
draw heatmap for results
"""
plt.figure(figsize=(10, 8), dpi=80)
plt.imshow(matrix, cmap='bwr', extent=(0.5, len(matrix)+0.5, 0.5, len(matrix)+0.5),
origin='lower')
plt.imshow(
matrix,
cmap="bwr",
extent=(0.5, len(matrix) + 0.5, 0.5, len(matrix) + 0.5),
origin="lower",
)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.xlabel('task round', fontsize=15)
plt.ylabel('task', fontsize=15)
plt.xlabel("task round", fontsize=15)
plt.ylabel("task", fontsize=15)
plt.title(title, fontsize=15)
plt.colorbar(format='%.2f')
plt.colorbar(format="%.2f")
output_dir = os.path.join(output, f"output/{title}-heatmap.png")
#print(output_dir)
# print(output_dir)
plt.savefig(output_dir)
plt.show()


def get_visualization_func(mode):
"""get visualization func"""
return getattr(sys.modules[__name__], mode)
20 changes: 13 additions & 7 deletions core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
from core.common.utils import load_module
from core.testcasecontroller.algorithm.module import Module
from core.testcasecontroller.algorithm.paradigm import (
SingleTaskLearning,
IncrementalLearning,
MultiedgeInference,
LifelongLearning,
MultiedgeInference,
SingleTaskLearning,
)
from core.testcasecontroller.generation_assistant import get_full_combinations


class Algorithm:
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(self, name, config):
}
self.lifelong_learning_data_setting: dict = {
"train_ratio": 0.8,
"splitting_method": "default"
"splitting_method": "default",
}
self.initial_model_url: str = ""
self.modules: list = []
Expand Down Expand Up @@ -108,7 +109,9 @@ def paradigm(self, workspace: str, **kwargs):

def _check_fields(self):
if not self.name and not isinstance(self.name, str):
raise ValueError(f"algorithm name({self.name}) must be provided and be string type.")
raise ValueError(
f"algorithm name({self.name}) must be provided and be string type."
)

if not self.paradigm_type and not isinstance(self.paradigm_type, str):
raise ValueError(
Expand All @@ -131,7 +134,8 @@ def _check_fields(self):
if not isinstance(self.lifelong_learning_data_setting, dict):
raise ValueError(
f"algorithm lifelong_learning_data_setting"
f"({self.lifelong_learning_data_setting} must be dictionary type.")
f"({self.lifelong_learning_data_setting} must be dictionary type."
)

if not isinstance(self.initial_model_url, str):
raise ValueError(
Expand Down Expand Up @@ -184,5 +188,7 @@ def _load_third_party_packages(self):
try:
load_module(url)
except Exception as err:
raise RuntimeError(f"load third party packages(name={name}, url={url}) failed,"
f" error: {err}.") from err
raise RuntimeError(
f"load third party packages(name={name}, url={url}) failed,"
f" error: {err}."
) from err
Loading

0 comments on commit 173e614

Please sign in to comment.