From d79ac984283b801b0b18f661ffdd16bdccafaa71 Mon Sep 17 00:00:00 2001 From: Bernie Beckerman Date: Mon, 6 May 2024 07:15:16 -0700 Subject: [PATCH] Update tests so that data contains metric_names that are on the corresponding experiments (#2422) Summary: In D56634321, observations_from_dataframe fails if there are metric_names Data.df that don't also exist on the experiment. This adjusts tests so that they avoid this issue. Reviewed By: saitcakmak Differential Revision: D56850033 --- ax/core/tests/test_observation.py | 8 ++++++- ax/modelbridge/tests/test_alebo_strategy.py | 2 +- ax/modelbridge/tests/test_registry.py | 2 +- ax/plot/tests/test_fitted_scatter.py | 10 +++++---- ax/service/tests/scheduler_test_utils.py | 23 ++++++++++----------- ax/service/tests/test_ax_client.py | 1 - 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 3aded6b6031..e229c9d2d57 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -274,6 +274,7 @@ def test_ObservationsFromData(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(truth)[ ["arm_name", "trial_index", "mean", "sem", "metric_name"] @@ -362,6 +363,7 @@ def test_ObservationsFromDataWithFidelities(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(list(truth.values()))[ ["arm_name", "trial_index", "mean", "sem", "metric_name", "fidelities"] @@ -443,6 +445,7 @@ def test_ObservationsFromMapData(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(list(truth.values()))[ ["arm_name", "trial_index", "mean", "sem", "metric_name", "z", "timestamp"] @@ -559,6 +562,7 @@ def test_ObservationsFromDataAbandoned(self) -> None: trials.get(2).mark_arm_abandoned(arm_name="2_1") type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(list(truth.values()))[ ["arm_name", "trial_index", "mean", "sem", "metric_name"] @@ -635,6 +639,7 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(truth)[ ["arm_name", "trial_index", "mean", "sem", "metric_name", "start_time"] @@ -719,7 +724,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms_by_name) type(experiment).trials = PropertyMock(return_value=trials) - + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(truth)[ [ "arm_name", @@ -856,6 +861,7 @@ def test_ObservationsWithCandidateMetadata(self) -> None: } type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) + type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"}) df = pd.DataFrame(truth)[ ["arm_name", "trial_index", "mean", "sem", "metric_name"] diff --git a/ax/modelbridge/tests/test_alebo_strategy.py b/ax/modelbridge/tests/test_alebo_strategy.py index ac05197a9d2..eeade2314fd 100644 --- a/ax/modelbridge/tests/test_alebo_strategy.py +++ b/ax/modelbridge/tests/test_alebo_strategy.py @@ -32,7 +32,7 @@ def test_factory_functions(self) -> None: pd.DataFrame( { "arm_name": ["0_0", "0_1", "0_2"], - "metric_name": "y", + "metric_name": "branin", "mean": [-1.0, 0.0, 1.0], "sem": 0.1, } diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 1d38c0b42ba..2cbd5260305 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -380,7 +380,7 @@ def test_ALEBO(self) -> None: pd.DataFrame( { "arm_name": ["0_0", "0_1", "0_2"], - "metric_name": "y", + "metric_name": "branin", "mean": [-1.0, 0.0, 1.0], "sem": 0.1, } diff --git a/ax/plot/tests/test_fitted_scatter.py b/ax/plot/tests/test_fitted_scatter.py index f8c6ad53c1b..6b33065c580 100644 --- a/ax/plot/tests/test_fitted_scatter.py +++ b/ax/plot/tests/test_fitted_scatter.py @@ -14,7 +14,7 @@ from ax.plot.base import AxPlotConfig from ax.plot.scatter import interact_fitted, interact_fitted_plotly from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_metric from ax.utils.testing.mock import fast_botorch_optimize @@ -27,6 +27,7 @@ def test_fitted_scatter(self) -> None: data = exp.fetch_data() df = deepcopy(data.df) df["metric_name"] = "branin_dup" + exp.add_tracking_metric(get_branin_metric(name="branin_dup")) model = Models.BOTORCH_MODULAR( # Model bridge kwargs @@ -47,11 +48,12 @@ def test_fitted_scatter(self) -> None: self.assertIsInstance(plot, AxPlotConfig) # Make sure all parameters and metrics are displayed in tooltips - tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys()) - for d in plot.data["data"]: + metric_names = ["branin", "branin_dup", "branin:agg"] + tooltips = [list(exp.parameters.keys()) + [m_name] for m_name in metric_names] + for idata, d in enumerate(plot.data["data"]): # Only check scatter plots hoverovers if d["type"] != "scatter": continue for text in d["text"]: - for tt in tooltips: + for tt in tooltips[idata]: self.assertTrue(tt in text) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index c8a4f4b8887..4a4d972c5d9 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -81,6 +81,7 @@ from sqlalchemy.orm.exc import StaleDataError DUMMY_EXCEPTION = "test_exception" +TEST_MEAN = 1.0 class SyntheticRunnerWithStatusPolling(SyntheticRunner): @@ -2147,28 +2148,26 @@ def test_it_does_not_overwrite_data_with_combine_fetch_kwarg(self) -> None: df=pd.DataFrame( { "arm_name": ["0_0"], - "metric_name": ["foo"], - "mean": [1.0], + "metric_name": ["branin"], + "mean": [TEST_MEAN], "sem": [0.1], "trial_index": [0], } ) ) ) - attached_metrics = ( - self.branin_experiment.lookup_data().df["metric_name"].unique() - ) + + attached_means = self.branin_experiment.lookup_data().df["mean"].unique() # the attach has overwritten the data, so we can infer that # fetching happened in the next `run_n_trials()` - self.assertNotIn("branin", attached_metrics) + self.assertIn(TEST_MEAN, attached_means) + self.assertEqual(len(attached_means), 1) scheduler.run_n_trials(max_trials=1) - attached_metrics = ( - self.branin_experiment.lookup_data().df["metric_name"].unique() - ) - # it did fetch again, but kept "foo" because of the combine kwarg - self.assertIn("foo", attached_metrics) - self.assertIn("branin", attached_metrics) + attached_means = self.branin_experiment.lookup_data().df["mean"].unique() + # it did fetch again, but kept both rows because of the combine kwarg + self.assertIn(TEST_MEAN, attached_means) + self.assertEqual(len(attached_means), 2) @fast_botorch_optimize def test_it_works_with_multitask_models( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 772c88e5196..34388798a73 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1500,7 +1500,6 @@ def test_trial_completion(self) -> None: ax_client = get_branin_optimization() params, idx = ax_client.get_next_trial() ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)}) - ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)}) metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values self.assertNotIn("m1", metrics_in_data) self.assertIn("branin", metrics_in_data)