Skip to content

Commit

Permalink
Merge pull request #45 from AndreFCruz/fix-test-seed
Browse files Browse the repository at this point in the history
Fixed random seed for flaky test
  • Loading branch information
mrtzh committed May 6, 2024
2 parents bb79df0 + 84cf670 commit d78adc5
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from folktables import ACSDataSource, ACSEmployment

SEED = 0

def test_download():
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
acs_data = data_source.get_data(states=["CA"], download=True)
Expand All @@ -22,17 +24,19 @@ def test_train():
features, label, group = ACSEmployment.df_to_numpy(acs_data)

X_train, X_test, y_train, y_test, group_train, group_test = train_test_split(
features, label, group, test_size=0.2, random_state=0)
features, label, group, test_size=0.2, random_state=SEED)

model = make_pipeline(StandardScaler(), LogisticRegression())
model = make_pipeline(StandardScaler(), LogisticRegression(random_state=SEED))
model.fit(X_train, y_train)

yhat = model.predict(X_test)

white_tpr = np.mean(yhat[(y_test == 1) & (group_test == 1)])
black_tpr = np.mean(yhat[(y_test == 1) & (group_test == 2)])

assert np.allclose(white_tpr - black_tpr, 0.04549392964278809)
ref_value = 0.04490694888127078
assert np.allclose(white_tpr - black_tpr, ref_value, atol=1e-2), \
f"Expected {ref_value}, got {white_tpr - black_tpr}"

if __name__ == "__main__":
test_download()
Expand Down

0 comments on commit d78adc5

Please sign in to comment.