Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tests so that data contains metric_names that are on the corresponding experiments #2422

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8 changes: 7 additions & 1 deletion ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def test_ObservationsFromData(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down Expand Up @@ -362,6 +363,7 @@ def test_ObservationsFromDataWithFidelities(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name", "fidelities"]
Expand Down Expand Up @@ -443,6 +445,7 @@ def test_ObservationsFromMapData(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name", "z", "timestamp"]
Expand Down Expand Up @@ -559,6 +562,7 @@ def test_ObservationsFromDataAbandoned(self) -> None:
trials.get(2).mark_arm_abandoned(arm_name="2_1")
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(list(truth.values()))[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down Expand Up @@ -635,6 +639,7 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name", "start_time"]
Expand Down Expand Up @@ -719,7 +724,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms_by_name)
type(experiment).trials = PropertyMock(return_value=trials)

type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})
df = pd.DataFrame(truth)[
[
"arm_name",
Expand Down Expand Up @@ -856,6 +861,7 @@ def test_ObservationsWithCandidateMetadata(self) -> None:
}
type(experiment).arms_by_name = PropertyMock(return_value=arms)
type(experiment).trials = PropertyMock(return_value=trials)
type(experiment).metrics = PropertyMock(return_value={"a": "a", "b": "b"})

df = pd.DataFrame(truth)[
["arm_name", "trial_index", "mean", "sem", "metric_name"]
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_alebo_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_factory_functions(self) -> None:
pd.DataFrame(
{
"arm_name": ["0_0", "0_1", "0_2"],
"metric_name": "y",
"metric_name": "branin",
"mean": [-1.0, 0.0, 1.0],
"sem": 0.1,
}
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_ALEBO(self) -> None:
pd.DataFrame(
{
"arm_name": ["0_0", "0_1", "0_2"],
"metric_name": "y",
"metric_name": "branin",
"mean": [-1.0, 0.0, 1.0],
"sem": 0.1,
}
Expand Down
10 changes: 6 additions & 4 deletions ax/plot/tests/test_fitted_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ax.plot.base import AxPlotConfig
from ax.plot.scatter import interact_fitted, interact_fitted_plotly
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_metric
from ax.utils.testing.mock import fast_botorch_optimize


Expand All @@ -27,6 +27,7 @@ def test_fitted_scatter(self) -> None:
data = exp.fetch_data()
df = deepcopy(data.df)
df["metric_name"] = "branin_dup"
exp.add_tracking_metric(get_branin_metric(name="branin_dup"))

model = Models.BOTORCH_MODULAR(
# Model bridge kwargs
Expand All @@ -47,11 +48,12 @@ def test_fitted_scatter(self) -> None:
self.assertIsInstance(plot, AxPlotConfig)

# Make sure all parameters and metrics are displayed in tooltips
tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
for d in plot.data["data"]:
metric_names = ["branin", "branin_dup", "branin:agg"]
tooltips = [list(exp.parameters.keys()) + [m_name] for m_name in metric_names]
for idata, d in enumerate(plot.data["data"]):
# Only check scatter plots hoverovers
if d["type"] != "scatter":
continue
for text in d["text"]:
for tt in tooltips:
for tt in tooltips[idata]:
self.assertTrue(tt in text)
23 changes: 11 additions & 12 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from sqlalchemy.orm.exc import StaleDataError

DUMMY_EXCEPTION = "test_exception"
TEST_MEAN = 1.0


class SyntheticRunnerWithStatusPolling(SyntheticRunner):
Expand Down Expand Up @@ -2147,28 +2148,26 @@ def test_it_does_not_overwrite_data_with_combine_fetch_kwarg(self) -> None:
df=pd.DataFrame(
{
"arm_name": ["0_0"],
"metric_name": ["foo"],
"mean": [1.0],
"metric_name": ["branin"],
"mean": [TEST_MEAN],
"sem": [0.1],
"trial_index": [0],
}
)
)
)
attached_metrics = (
self.branin_experiment.lookup_data().df["metric_name"].unique()
)

attached_means = self.branin_experiment.lookup_data().df["mean"].unique()
# the attach has overwritten the data, so we can infer that
# fetching happened in the next `run_n_trials()`
self.assertNotIn("branin", attached_metrics)
self.assertIn(TEST_MEAN, attached_means)
self.assertEqual(len(attached_means), 1)

scheduler.run_n_trials(max_trials=1)
attached_metrics = (
self.branin_experiment.lookup_data().df["metric_name"].unique()
)
# it did fetch again, but kept "foo" because of the combine kwarg
self.assertIn("foo", attached_metrics)
self.assertIn("branin", attached_metrics)
attached_means = self.branin_experiment.lookup_data().df["mean"].unique()
# it did fetch again, but kept both rows because of the combine kwarg
self.assertIn(TEST_MEAN, attached_means)
self.assertEqual(len(attached_means), 2)

@fast_botorch_optimize
def test_it_works_with_multitask_models(
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,6 @@ def test_trial_completion(self) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})
ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values
self.assertNotIn("m1", metrics_in_data)
self.assertIn("branin", metrics_in_data)
Expand Down
Loading