Skip to content

Commit

Permalink
Update tests so that data contains metric_names that are on the corre…
Browse files Browse the repository at this point in the history
…sponding experiments

Summary: In D56634321, observations_from_dataframe fails if there are metric_names Data.df that don't also exist on the experiment. This adjusts tests so that they avoid this issue.

Differential Revision: D56850033
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed May 3, 2024
1 parent f007848 commit 123f3bf
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 20 deletions.
8 changes: 7 additions & 1 deletion ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def test_ObservationsFromData(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down Expand Up @@ -362,6 +363,7 @@ def test_ObservationsFromDataWithFidelities(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name", "fidelities"]
Expand Down Expand Up @@ -443,6 +445,7 @@ def test_ObservationsFromMapData(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name", "z", "timestamp"]
Expand Down Expand Up @@ -559,6 +562,7 @@ def test_ObservationsFromDataAbandoned(self) -> None:
trials.get(2).mark_arm_abandoned(arm_name="2_1")
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down Expand Up @@ -635,6 +639,7 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name", "start_time"]
Expand Down Expand Up @@ -719,7 +724,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms_by_name)
type(experiment).trials = PropertyMock(return_value=trials)

type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})
df = pd.DataFrame(truth)[
[
"arm_name",
Expand Down Expand Up @@ -856,6 +861,7 @@ def test_ObservationsWithCandidateMetadata(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_alebo_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_factory_functions(self) -> None:
pd.DataFrame(
{
"arm_name": ["0_0", "0_1", "0_2"],
"metric_name": "y",
"metric_name": "branin",
"mean": [-1.0, 0.0, 1.0],
"sem": 0.1,
}
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_ALEBO(self) -> None:
pd.DataFrame(
{
"arm_name": ["0_0", "0_1", "0_2"],
"metric_name": "y",
"metric_name": "branin",
"mean": [-1.0, 0.0, 1.0],
"sem": 0.1,
}
Expand Down
10 changes: 6 additions & 4 deletions ax/plot/tests/test_fitted_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ax.plot.base import AxPlotConfig
from ax.plot.scatter import interact_fitted, interact_fitted_plotly
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_metric
from ax.utils.testing.mock import fast_botorch_optimize


Expand All @@ -27,6 +27,7 @@ def test_fitted_scatter(self) -> None:
data = exp.fetch_data()
df = deepcopy(data.df)
df["metric_name"] = "branin_dup"
exp.add_tracking_metric(get_branin_metric(name="branin_dup"))

model = Models.BOTORCH_MODULAR(
# Model bridge kwargs
Expand All @@ -47,11 +48,12 @@ def test_fitted_scatter(self) -> None:
self.assertIsInstance(plot, AxPlotConfig)

# Make sure all parameters and metrics are displayed in tooltips
tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
for d in plot.data["data"]:
metric_names = ["branin", "branin_dup", "branin:agg"]
tooltips = [list(exp.parameters.keys()) + [m_name] for m_name in metric_names]
for idata, d in enumerate(plot.data["data"]):
# Only check scatter plots hoverovers
if d["type"] != "scatter":
continue
for text in d["text"]:
for tt in tooltips:
for tt in tooltips[idata]:
self.assertTrue(tt in text)
23 changes: 11 additions & 12 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from sqlalchemy.orm.exc import StaleDataError

DUMMY_EXCEPTION = "test_exception"
TEST_MEAN = 1.0


class SyntheticRunnerWithStatusPolling(SyntheticRunner):
Expand Down Expand Up @@ -2147,28 +2148,26 @@ def test_it_does_not_overwrite_data_with_combine_fetch_kwarg(self) -> None:
df=pd.DataFrame(
{
"arm_name": ["0_0"],
"metric_name": ["foo"],
"mean": [1.0],
"metric_name": ["branin"],
"mean": [TEST_MEAN],
"sem": [0.1],
"trial_index": [0],
}
)
)
)
attached_metrics = (
self.branin_experiment.lookup_data().df["metric_name"].unique()
)

attached_means = self.branin_experiment.lookup_data().df["mean"].unique()
# the attach has overwritten the data, so we can infer that
# fetching happened in the next `run_n_trials()`
self.assertNotIn("branin", attached_metrics)
self.assertIn(TEST_MEAN, attached_means)
self.assertEqual(len(attached_means), 1)

scheduler.run_n_trials(max_trials=1)
attached_metrics = (
self.branin_experiment.lookup_data().df["metric_name"].unique()
)
# it did fetch again, but kept "foo" because of the combine kwarg
self.assertIn("foo", attached_metrics)
self.assertIn("branin", attached_metrics)
attached_means = self.branin_experiment.lookup_data().df["mean"].unique()
# it did fetch again, but kept both rows because of the combine kwarg
self.assertIn(TEST_MEAN, attached_means)
self.assertEqual(len(attached_means), 2)

@fast_botorch_optimize
def test_it_works_with_multitask_models(
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,6 @@ def test_trial_completion(self) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})
ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values
self.assertNotIn("m1", metrics_in_data)
self.assertIn("branin", metrics_in_data)
Expand Down

0 comments on commit 123f3bf

Please sign in to comment.