Skip to content

Commit

Permalink
Merge pull request #135 from alexheat/dev
Browse files Browse the repository at this point in the history
Improve handling for splitter issue #134
  • Loading branch information
alexheat committed Nov 21, 2023
2 parents 316fd6c + 4d028f3 commit db0bf7c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
penv/
env/
venv/
.vscode/settings.json
coco_instances_val2017.json
mypythonlib.egg-info/
Expand All @@ -19,4 +20,4 @@ road_sign_data.yaml
BCCD_Dataset/
model_training/
__pycache__
samples
samples
28 changes: 27 additions & 1 deletion pylabel/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,18 @@ def GroupShuffleSplit(
):
"""
This function uses the GroupShuffleSplit command from sklearn. It can split into 3 groups (train,
test, and val) by applying the command twice.
test, and val) by applying the command twice. If you want to split into only 2 groups (train and test),
then set val_pct to 0.
"""

# Check inputs and raise errors if needed
assert 0 < float(train_pct) < 1, "train_pct must be between 0 and 1"
assert 0 < float(test_pct) < 1, "test_pct must be between 0 and 1"
# check that the sum of train_pct, test_pct, and val_pct is equal to 1
assert (
round(train_pct + test_pct + val_pct, 1) == 1
), "Sum of train_pct, test_pct, and val_pct must equal 1."

df_main = self.dataset.df
gss = sklearnGroupShuffleSplit(
n_splits=1, train_size=train_pct, random_state=random_state
Expand Down Expand Up @@ -69,6 +79,22 @@ def StratifiedGroupShuffleSplit(
train, test, or val. When a split dataset is exported the annotations will be split into
separate groups so that they can be used in model training, testing, and validation.
"""

# Check inputs and raise errors if needed
assert (
0 <= float(train_pct) <= 1
), "train_pct must be greater than or equal to 0 and less than or equal to 1"
assert (
0 <= float(test_pct) <= 1
), "test_pct must be greater than or equal to 0 and less than or equal to 1"
assert (
0 <= float(val_pct) <= 1
), "val_pct must be greater than or equal to 0 and less than or equal to 1"
# check that the sum of train_pct, test_pct, and val_pct is equal to 1
assert (
round(train_pct + test_pct + val_pct, 1) == 1
), "Sum of train_pct, test_pct, and val_pct must equal 1."

df_main = self.dataset.df
df_main = df_main.reindex(
np.random.permutation(df_main.index)
Expand Down

0 comments on commit db0bf7c

Please sign in to comment.