Skip to content

Commit

Permalink
chore: fix mp.pool test_streaming_metrics_api (#8917)
Browse files Browse the repository at this point in the history
(cherry picked from commit ad7d260)
  • Loading branch information
NicholasBlaskey authored and determined-ci committed Feb 29, 2024
1 parent 18e2ea4 commit 435e90a
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions e2e_tests/tests/experiment/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
import multiprocessing as mp
from multiprocessing import pool
from typing import Dict, List, Set, Union

import pytest
Expand All @@ -15,7 +15,7 @@
@pytest.mark.timeout(600)
def test_streaming_metrics_api() -> None:
sess = api_utils.user_session()
pool = mp.pool.ThreadPool(processes=7)
thread_pool = pool.ThreadPool(processes=7)

experiment_id = exp.create_experiment(
sess,
Expand All @@ -29,13 +29,25 @@ def test_streaming_metrics_api() -> None:
# The HP importance portion of this test is commented out until the feature is enabled by
# default

metric_names_thread = pool.apply_async(request_metric_names, (experiment_id,))
train_metric_batches_thread = pool.apply_async(request_train_metric_batches, (experiment_id,))
valid_metric_batches_thread = pool.apply_async(request_valid_metric_batches, (experiment_id,))
train_trials_snapshot_thread = pool.apply_async(request_train_trials_snapshot, (experiment_id,))
valid_trials_snapshot_thread = pool.apply_async(request_valid_trials_snapshot, (experiment_id,))
train_trials_sample_thread = pool.apply_async(request_train_trials_sample, (experiment_id,))
valid_trials_sample_thread = pool.apply_async(request_valid_trials_sample, (experiment_id,))
metric_names_thread = thread_pool.apply_async(request_metric_names, (experiment_id,))
train_metric_batches_thread = thread_pool.apply_async(
request_train_metric_batches, (experiment_id,)
)
valid_metric_batches_thread = thread_pool.apply_async(
request_valid_metric_batches, (experiment_id,)
)
train_trials_snapshot_thread = thread_pool.apply_async(
request_train_trials_snapshot, (experiment_id,)
)
valid_trials_snapshot_thread = thread_pool.apply_async(
request_valid_trials_snapshot, (experiment_id,)
)
train_trials_sample_thread = thread_pool.apply_async(
request_train_trials_sample, (experiment_id,)
)
valid_trials_sample_thread = thread_pool.apply_async(
request_valid_trials_sample, (experiment_id,)
)

metric_names_results = metric_names_thread.get()
train_metric_batches_results = train_metric_batches_thread.get()
Expand Down

0 comments on commit 435e90a

Please sign in to comment.