Skip to content

Commit

Permalink
Update test_splitting.py
Browse files Browse the repository at this point in the history
Adjusted the test cases to the new OO-implementation of splitting.py
  • Loading branch information
haidarjomaa authored Jan 23, 2024
1 parent 5ae126d commit b07936e
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions data_processing/test_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,31 @@ def setUp(self):
'label': np.random.choice([0, 1], size=num_samples)
}
self.df = pd.DataFrame(data)
self.splitter = None

def test_split_k_folds(self):
# Verifies that a split requested with nfolds=5 yields five distinct values in the 'fold' column.
# NOTE(review): this text is a flattened diff view (+/- markers and indentation were lost).
# The call to the free function on the next line looks like the removed pre-change line, and the
# two lines after it look like its replacement through the Splitting class; confirm against the file.
df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
df_result = self.splitter.split_K_stratified_folds()
# Collect the distinct fold identifiers present in the result.
unique_folds = df_result['fold'].unique()

self.assertEqual(len(unique_folds), 5) # Check if correct number of folds are created

def test_split_k_folds_verbose(self):
# Same five-fold check as the non-verbose test, but with verbose=True passed to the splitter.
# Only the fold count is asserted; nothing the verbose mode may print is inspected here.
# NOTE(review): this text is a flattened diff view (+/- markers and indentation were lost).
# The free-function call on the next line looks like the removed pre-change line, and the two
# lines after it look like its replacement through the Splitting class; confirm against the file.
df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'], verbose=True)
self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'], verbose=True)
df_result =self.splitter.split_K_stratified_folds()

# Collect the distinct fold identifiers present in the result.
unique_folds = df_result['fold'].unique()

self.assertEqual(len(unique_folds), 5) # Check if correct number of folds are created

def test_split_k_folds_labels(self):
# Sanity check on label counts in the split result.
# NOTE(review): this text is a flattened diff view (+/- markers and indentation were lost).
# The free-function call on the next line looks like the removed pre-change line, and the two
# lines after it look like its replacement through the Splitting class; confirm against the file.
df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
df_result =self.splitter.split_K_stratified_folds()
# Moves 'fold' and 'label' into the index, then groups on the 'label' level only, so count()
# gives the number of non-null values per remaining column for each label across ALL folds.
unique_labels = df_result.set_index(['fold', 'label']).groupby(level='label').count()

# Asserts that the smallest per-label count of 'id' exceeds 1.
# NOTE(review): because the grouping ignores 'fold', this does not actually establish that every
# label appears in every fold, which is what the trailing comment states; consider grouping on
# both levels if that stronger property is intended.
self.assertTrue(unique_labels.min()['id'] > 1) # Check if each label has samples in each fold

def test_split_k_folds_reset_index(self):
# Verifies that the returned frame exposes 'id' as a regular column.
# NOTE(review): this text is a flattened diff view (+/- markers and indentation were lost).
# The free-function call on the next line looks like the removed pre-change line, and the two
# lines after it look like its replacement through the Splitting class; confirm against the file.
df_result = split_K_stratified_folds(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
self.splitter = Splitting(self.df, nfolds=5, seed=42, id_key='id', split_key='split_key', label_keys=['label'])
df_result =self.splitter.split_K_stratified_folds()
# Presumably the splitter indexes by id_key internally and resets the index before returning;
# that is not visible here, only the presence of the column is checked. TODO confirm.
self.assertTrue('id' in df_result.columns) # Check if 'id' column is present after resetting index

0 comments on commit b07936e

Please sign in to comment.