From b07936e249599b607512212aee3f654d568e9590 Mon Sep 17 00:00:00 2001 From: Haidar Jomaa <130698588+haidarjomaa@users.noreply.github.com> Date: Tue, 23 Jan 2024 22:58:37 +0200 Subject: [PATCH] Update test_splitting.py Adjusted the test cases to the new OO-implementation of splitting.py --- data_processing/test_splitting.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/data_processing/test_splitting.py b/data_processing/test_splitting.py index 28b65d8..cafba85 100644 --- a/data_processing/test_splitting.py +++ b/data_processing/test_splitting.py @@ -14,26 +14,31 @@ def setUp(self): 'label': np.random.choice([0, 1], size=num_samples) } self.df = pd.DataFrame(data) + self.splitter = None def test_split_k_folds(self): - df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + df_result = self.splitter.split_K_stratified_folds() unique_folds = df_result['fold'].unique() self.assertEqual(len(unique_folds), 5) # Check if correct number of folds are created def test_split_k_folds_verbose(self): - df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'], verbose=True) + self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'], verbose=True) + df_result =self.splitter.split_K_stratified_folds() unique_folds = df_result['fold'].unique() self.assertEqual(len(unique_folds), 5) # Check if correct number of folds are created def test_split_k_folds_labels(self): - df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + df_result =self.splitter.split_K_stratified_folds() unique_labels = df_result.set_index(['fold', 'label']).groupby(level='label').count() self.assertTrue(unique_labels.min()['id'] > 1) # Check if each label has samples in each fold def test_split_k_folds_reset_index(self): - df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label']) + df_result =self.splitter.split_K_stratified_folds() self.assertTrue('id' in df_result.columns) # Check if 'id' column is present after resetting index