diff --git a/core/common/constant.py b/core/common/constant.py
index 64eda172..98166d3c 100644
--- a/core/common/constant.py
+++ b/core/common/constant.py
@@ -51,7 +51,7 @@ class ModuleType(Enum):
     CLOUDMODEL = "cloudmodel"
 
     # Dataset Preprocessor
-    DATA_PROCESSOR = "dataset_processor" 
+    DATA_PROCESSOR = "dataset_processor"
 
     # HEM
     HARD_EXAMPLE_MINING = "hard_example_mining"
diff --git a/core/testcasecontroller/algorithm/paradigm/base.py b/core/testcasecontroller/algorithm/paradigm/base.py
index 4b079f5b..144660ab 100644
--- a/core/testcasecontroller/algorithm/paradigm/base.py
+++ b/core/testcasecontroller/algorithm/paradigm/base.py
@@ -101,8 +101,9 @@ def build_paradigm_job(self, paradigm_type):
 
         if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
             return LifelongLearning(
-                estimator=self.module_instances.get(
+                seen_estimator=self.module_instances.get(
                     ModuleType.BASEMODEL.value),
+                unseen_estimator=None,
                 task_definition=self.module_instances.get(
                     ModuleType.TASK_DEFINITION.value),
                 task_relationship_discovery=self.module_instances.get(
diff --git a/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py b/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py
index c6fcaa52..5681b449 100644
--- a/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py
+++ b/core/testcasecontroller/algorithm/paradigm/joint_inference/joint_inference.py
@@ -50,6 +50,7 @@ class JointInference(ParadigmBase):
 
     def __init__(self, workspace, **kwargs):
         ParadigmBase.__init__(self, workspace, **kwargs)
+        self.inference_dataset = None
         self.kwargs = kwargs
         self.hard_example_mining_mode = kwargs.get(
             "hard_example_mining_mode",
@@ -57,6 +58,11 @@ def __init__(self, workspace, **kwargs):
         )
 
     def set_config(self):
+        """Configure output_dir, dataset, modules
+
+        Raises:
+            KeyError: Required Modules are not fully loaded.
+        """
         inference_output_dir = os.path.dirname(self.workspace)
         os.environ["RESULT_SAVED_URL"] = inference_output_dir
 
@@ -65,7 +71,7 @@ def set_config(self):
 
         LOGGER.info("Loading dataset")
         self.inference_dataset = self.dataset.load_data(
-            self.dataset.test_data_info, 
+            self.dataset.test_data_info,
             "inference"
         )
@@ -77,12 +83,13 @@ def set_config(self):
         required_modules = {"edgemodel", "cloudmodel", "hard_example_mining"}
 
         if not required_modules.issubset(set(self.module_instances.keys())):
-            raise ValueError(
+            raise KeyError(
                 f"Required modules: {required_modules}, "
                 f"but got: {self.module_instances.keys()}"
             )
-        
+
+        # if hard example mining is OracleRouter,
+        # add the edgemodel and cloudmodel object to its kwargs so that it can use them.
         mining = self.module_instances["hard_example_mining"]
         param = mining.get("param")
         if mining.get("method", None) == "OracleRouter":
@@ -116,12 +123,13 @@ def _cleanup(self, job):
         for module in self.module_instances.values():
             if hasattr(module, "cleanup"):
                 module.cleanup()
-        
+
+        # Since the hard example mining module is instantiated within the job,
+        # special call is required.
-        # Since the hard example mining module is instantiated within the job, special handling is required.
         mining_instance = job.hard_example_mining_algorithm
         if hasattr(mining_instance, "cleanup"):
             mining_instance.cleanup()
-        
+
         del job
 
     def _inference(self, job):
@@ -132,7 +140,7 @@ def _inference(self, job):
 
         LOGGER.info("Inference Start")
         pbar = tqdm(
-            self.inference_dataset.x, 
+            self.inference_dataset.x,
             total=len(self.inference_dataset.x),
             ncols=100
         )
diff --git a/core/testenvmanager/dataset/dataset.py b/core/testenvmanager/dataset/dataset.py
index 0a1d04e6..9b7070cc 100644
--- a/core/testenvmanager/dataset/dataset.py
+++ b/core/testenvmanager/dataset/dataset.py
@@ -29,7 +29,7 @@
 from core.common import utils
 from core.common.constant import DatasetFormat
 
-
+# pylint: disable=too-many-instance-attributes
 class Dataset:
     """
     Data:
@@ -420,7 +420,8 @@ def _hard_example_splitting(self, data_file, data_format, ratio,
         return data_files
 
     @classmethod
-    def load_data(cls, file: str, data_type: str, label=None, use_raw=False, feature_process=None, **kwargs):
+    def load_data(cls, file: str, data_type: str, label=None,
+                  use_raw=False, feature_process=None, **kwargs):
         """
         load data