
Merge pull request #670 from predictive-analytics-lab/parallel-pytest
Parallel pytest
tmke8 committed May 28, 2022
2 parents 3e2d623 + 6858404 commit 3953dc4
Showing 8 changed files with 99 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
@@ -144,4 +144,4 @@ jobs:
#----------------------------------------------
- name: Test with pytest
run: |
-poetry run python -m pytest -vv --cov=ethicml --cov-fail-under=80 tests/
+poetry run python -m pytest -vv -n 2 --dist loadgroup --cov=ethicml --cov-fail-under=80 tests/
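The new flags come from pytest-xdist: -n 2 runs the suite in two worker processes, and --dist loadgroup keeps every test that carries the same @pytest.mark.xdist_group(...) name on a single worker, so grouped tests still run one after another. That is why the test files below gain xdist_group markers on tests that read or write the same files on disk. A minimal sketch of the pattern; the tests and the shared path are hypothetical illustrations, not code from this repository:

# A minimal sketch of the loadgroup pattern. The tests and the shared path
# here are hypothetical illustrations, not code from this repository.
from pathlib import Path

import pytest

SHARED_FILE = Path("results/summary.csv")  # assumed shared on-disk artifact


@pytest.mark.xdist_group("results_files")
def test_write_summary() -> None:
    SHARED_FILE.parent.mkdir(exist_ok=True)
    SHARED_FILE.write_text("accuracy,0.9\n")


@pytest.mark.xdist_group("results_files")
def test_read_summary() -> None:
    # Pinned to the same worker as test_write_summary by --dist loadgroup,
    # so the two tests cannot race on SHARED_FILE from separate processes.
    assert SHARED_FILE.read_text().startswith("accuracy")

Without the marker, -n 2 could schedule the two tests onto different processes at the same time, and the read could see a half-written or missing file.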
103 changes: 78 additions & 25 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -62,6 +62,7 @@ pytest-cov = ">=2.6,<4.0"
python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "316dad4"}
types-Pillow = "^9.0.15"
omegaconf = ">=2.2.1"
+pytest-xdist = "^2.5.0"

[tool.black]
line-length = 100
3 changes: 3 additions & 0 deletions tests/models_test/inprocess_test/models_inprocessing_test.py
@@ -148,6 +148,7 @@ def test_kamiran_weights(toy_train_test: TrainTestPair):


@pytest.mark.parametrize("name,model,num_pos", INPROCESS_TESTS)
+@pytest.mark.xdist_group("in_model_files")
def test_inprocess_sep_train_pred(
toy_train_val: TrainValPair, name: str, model: InAlgorithm, num_pos: int
):
@@ -217,6 +218,7 @@ def kamishima_gen() -> Generator[Kamishima, None, None]:


@pytest.mark.slow
+@pytest.mark.xdist_group("in_model_files")
def test_kamishima(toy_train_test: TrainTestPair, kamishima_gen: Kamishima) -> None:
"""Test Kamishima."""
train, test = toy_train_test
@@ -266,6 +268,7 @@ def get_hyperparameters(self) -> HyperParamType:


@pytest.mark.slow
+@pytest.mark.xdist_group("results_files")
def test_threaded_agarwal():
"""Test threaded agarwal."""
models: List[InAlgorithmSubprocess] = [
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@ def test_post(
PostprocessTest(post_model=Hardt(), name="Hardt", num_pos=35),
],
)
+@pytest.mark.xdist_group("post_model_files")
def test_post_sep_fit_pred(
toy_train_val: TrainValPair, post_model: PostAlgorithm, name: str, num_pos: int
) -> None:
Original file line number Diff line number Diff line change
@@ -112,6 +112,7 @@ def test_pre(toy_train_test: TrainTestPair, model: PreAlgorithm, name: str, num_


@pytest.mark.parametrize("model,name,num_pos", METHOD_LIST)
+@pytest.mark.xdist_group("pre_model_files")
def test_pre_sep_fit_transform(
toy_train_val: TrainValPair, model: PreAlgorithm, name: str, num_pos: int
):
8 changes: 8 additions & 0 deletions tests/run_algorithm_test.py
@@ -53,6 +53,7 @@ def test_run_parallel(toy_train_val: em.TrainValPair):


@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_empty_evaluate():
"""Test empty evaluate."""
empty_result = em.evaluate_models([em.Toy()], repeats=3)
@@ -67,6 +68,7 @@ def test_empty_evaluate():

@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_repeats_error(repeats: int):
"""Add a test to check that the right number of reults are produced with repeats."""
dataset = em.Adult(split=em.Adult.Splits.RACE_BINARY)
@@ -92,6 +94,7 @@ def test_run_alg_repeats_error(repeats: int):
@pytest.mark.parametrize("on", ["data", "model", "both"])
@pytest.mark.parametrize("repeats", [2, 3, 5])
@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_repeats(repeats: int, on: Literal["data", "model", "both"]):
"""Check the repeat_on arg."""
dataset = em.Adult(split=em.Adult.Splits.RACE_BINARY)
@@ -124,6 +127,7 @@ def test_run_repeats(repeats: int, on: Literal["data", "model", "both"]):


@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_suite_scaler():
"""Test run alg suite."""
dataset = em.Adult(split=em.Adult.Splits.RACE_BINARY)
@@ -160,6 +164,7 @@ def test_run_alg_suite_scaler():


@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_suite():
"""Test run alg suite."""
dataset = em.Adult(split=em.Adult.Splits.RACE_BINARY)
@@ -202,6 +207,7 @@ def test_run_alg_suite():


@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_suite_wrong_metrics():
"""Test run alg suite wrong metrics."""
datasets: List[em.Dataset] = [em.Toy(), em.Adult()]
@@ -223,6 +229,7 @@ def test_run_alg_suite_wrong_metrics():


@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_suite_err_handling():
"""Test run alg suite handles when an err is thrown."""

@@ -262,6 +269,7 @@ def get_name(self) -> str:

@pytest.mark.slow
@pytest.mark.usefixtures("results_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_run_alg_suite_no_pipeline():
"""Run alg suite while avoiding the 'fair pipeline'."""
datasets: List[em.Dataset] = [em.Toy(), em.Adult()]
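Every test above that uses the results_cleanup fixture is pinned to the same "results_files" group: each run writes evaluation output to a shared results location and the fixture removes it afterwards, so two of these tests running concurrently on different workers could delete each other's files mid-run. A hypothetical sketch of such a cleanup fixture follows; the real results_cleanup is defined in tests/conftest.py, is not part of this diff, and the directory name below is only an assumption:

# Hypothetical sketch of a results_cleanup-style fixture; the actual fixture
# lives in tests/conftest.py and may differ. The "results" directory name is
# an assumption for illustration.
import shutil
from pathlib import Path
from typing import Iterator

import pytest


@pytest.fixture
def results_cleanup() -> Iterator[None]:
    yield  # let the test write its results first
    results_dir = Path("results")
    if results_dir.exists():
        # Shared on-disk state: deleting it must not overlap with another
        # worker still writing to it, hence the shared xdist_group.
        shutil.rmtree(results_dir)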
6 changes: 6 additions & 0 deletions tests/visualisation_test.py
@@ -34,13 +34,15 @@

@pytest.mark.slow
@pytest.mark.usefixtures("plot_cleanup") # fixtures are defined in `tests/conftest.py`
+@pytest.mark.xdist_group("results_files")
def test_plot_tsne(toy_train_test: TrainTestPair):
"""Test plot."""
train, _ = toy_train_test
save_2d_plot(train, "./plots/test.png")


@pytest.mark.usefixtures("plot_cleanup") # fixtures are defined in `tests/conftest.py`
+@pytest.mark.xdist_group("results_files")
def test_plot_no_tsne(toy_train_test: TrainTestPair):
"""Test plot."""
train, _ = toy_train_test
@@ -49,6 +51,7 @@ def test_plot_no_tsne(toy_train_test: TrainTestPair):


@pytest.mark.usefixtures("plot_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_joint_plot(toy_train_test: TrainTestPair):
"""Test joint plot."""
train, _ = toy_train_test
@@ -57,13 +60,15 @@ def test_joint_plot(toy_train_test: TrainTestPair):

@pytest.mark.slow
@pytest.mark.usefixtures("plot_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_multijoint_plot(toy_train_test: TrainTestPair):
"""Test joint plot."""
train, _ = toy_train_test
save_multijointplot(train, "./plots/joint.png")


@pytest.mark.usefixtures("plot_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_label_plot():
"""Test label plot."""
data: DataTuple = load_data(Adult())
@@ -74,6 +79,7 @@ def test_label_plot():


@pytest.mark.usefixtures("plot_cleanup")
+@pytest.mark.xdist_group("results_files")
def test_plot_evals():
"""Test plot evals."""
results: Results = evaluate_models(
