diff --git a/.gitignore b/.gitignore index b6f7329..ac39966 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ penv/ env/ +venv/ .vscode/settings.json coco_instances_val2017.json mypythonlib.egg-info/ @@ -19,4 +20,4 @@ road_sign_data.yaml BCCD_Dataset/ model_training/ __pycache__ -samples +samples \ No newline at end of file diff --git a/pylabel/splitter.py b/pylabel/splitter.py index ea016e2..27df0c7 100644 --- a/pylabel/splitter.py +++ b/pylabel/splitter.py @@ -22,8 +22,18 @@ def GroupShuffleSplit( ): """ This function uses the GroupShuffleSplit command from sklearn. It can split into 3 groups (train, - test, and val) by applying the command twice. + test, and val) by applying the command twice. If you want to split into only 2 groups (train and test), + then set val_pct to 0. """ + + # Check inputs and raise errors if needed + assert 0 < float(train_pct) < 1, "train_pct must be between 0 and 1" + assert 0 < float(test_pct) < 1, "test_pct must be between 0 and 1" + # check that the sum of train_pct, test_pct, and val_pct is equal to 1 + assert ( + round(train_pct + test_pct + val_pct, 1) == 1 + ), "Sum of train_pct, test_pct, and val_pct must equal 1." + df_main = self.dataset.df gss = sklearnGroupShuffleSplit( n_splits=1, train_size=train_pct, random_state=random_state @@ -69,6 +79,22 @@ def StratifiedGroupShuffleSplit( train, test, or val. When a split dataset is exported the annotations will be split into seperate groups so that can be used used in model training, testing, and validation. """ + + # Check inputs and raise errors if needed + assert ( + 0 <= float(train_pct) <= 1 + ), "train_pct must be greater than or equal to 0 and less than or equal to 1" + assert ( + 0 <= float(test_pct) <= 1 + ), "test_pct must be greater than or equal to 0 and less than or equal to 1" + assert ( + 0 <= float(val_pct) <= 1 + ), "val_pct must be greater than or equal to 0 and less than or equal to 1" + # check that the sum of train_pct, test_pct, and val_pct is equal to 1 + assert ( + round(train_pct + test_pct + val_pct, 1) == 1 + ), "Sum of train_pct, test_pct, and val_pct must equal 1." + df_main = self.dataset.df df_main = df_main.reindex( np.random.permutation(df_main.index)