Skip to content

Commit

Permalink
save status quo name when multiple status quos present in the data (#…
Browse files Browse the repository at this point in the history
…2457)

Summary:
Pull Request resolved: #2457

When multiple status_quo is present in the data, we should still be able to save status quo name and features, which should be the same for all these status quos.

Reviewed By: esantorella

Differential Revision: D57309789

fbshipit-source-id: be2037bde9bd8e711d7a60520610d4a5784c334f
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed May 15, 2024
1 parent ac44a66 commit 6092ac3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 29 deletions.
46 changes: 25 additions & 21 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
self._optimization_config: Optional[OptimizationConfig] = optimization_config
self._training_in_design: List[bool] = []
self._status_quo: Optional[Observation] = None
self._status_quo_name: Optional[str] = None
self._arms_by_signature: Optional[Dict[str, Arm]] = None
self.transforms: MutableMapping[str, Transform] = OrderedDict()
self._model_key: Optional[str] = None
Expand Down Expand Up @@ -402,7 +403,7 @@ def _set_status_quo(
status_quo_name: Optional[str],
status_quo_features: Optional[ObservationFeatures],
) -> None:
"""Set model status quo.
"""Set model status quo by matching status_quo_name or status_quo_features.
First checks for status quo in inputs status_quo_name and
status_quo_features. If neither of these is provided, checks the
Expand All @@ -415,6 +416,7 @@ def _set_status_quo(
status_quo_features: Features for status quo.
"""
self._status_quo: Optional[Observation] = None
sq_obs = None

if (
status_quo_name is None
Expand All @@ -432,17 +434,6 @@ def _set_status_quo(
sq_obs = [
obs for obs in self._training_data if obs.arm_name == status_quo_name
]

if len(sq_obs) == 0:
logger.warning(f"Status quo {status_quo_name} not present in data")
elif len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
)
else:
self._status_quo = sq_obs[0]

elif status_quo_features is not None:
sq_obs = [
obs
Expand All @@ -451,25 +442,38 @@ def _set_status_quo(
and (obs.features.trial_index == status_quo_features.trial_index)
]

# if status_quo_name or status_quo_features is used for matching status quo
if sq_obs is not None:
if len(sq_obs) == 0:
logger.warning(
f"Status quo features {status_quo_features} not found in data."
)
else:
# len(sq_obs) will not be > 1,
# unique features verified in _set_training_data.
self._status_quo = sq_obs[0]
logger.warning(f"Status quo {status_quo_name} not present in data")
elif len(sq_obs) >= 1:
# status quo name (not features as trial index is part of the
# observation features) should be consistent even if we have multiple
# observations of the status quo.
# This is useful for getting status_quo_data_by_trial
self._status_quo_name = sq_obs[0].arm_name
if len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
)
else:
# if there is a unique status_quo, set it
# unique features verified in _set_training_data.
self._status_quo = sq_obs[0]

@property
def status_quo_data_by_trial(self) -> Optional[Dict[int, ObservationData]]:
"""A map of trial index to the status quo observation data of each trial"""
return _get_status_quo_by_trial(
observations=self._training_data,
status_quo_name=(
None if self._status_quo is None else self._status_quo.arm_name
self._status_quo_name
if self.status_quo is None
else self.status_quo.arm_name
),
status_quo_features=(
None if self._status_quo is None else self._status_quo.features
None if self.status_quo is None else self.status_quo.features
),
)

Expand Down
60 changes: 60 additions & 0 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_experiment_with_multi_objective,
Expand Down Expand Up @@ -769,6 +770,65 @@ def test_gen_on_experiment_with_imm_ss_and_opt_conf(self, _, __):
self.assertIsNone(gr.optimization_config)
self.assertIsNone(gr.search_space)

def test_set_status_quo(self) -> None:
# experiment with single status quo in trial
exp = get_branin_experiment(
with_batch=True,
with_status_quo=True,
num_batch_trial=1,
with_completed_batch=True,
)
modelbridge = ModelBridge(
search_space=exp.search_space,
experiment=exp,
model=Model,
data=exp.lookup_data(),
)

# we are able to set status_quo_data_by_trial when multiple
# status_quos present in each trial
self.assertIsNotNone(modelbridge.status_quo_data_by_trial)
# status_quo is set
self.assertIsNotNone(modelbridge.status_quo)
# Status quo name is logged
self.assertEqual(modelbridge._status_quo_name, not_none(exp.status_quo).name)

# experiment with multiple status quos in different trials
exp = get_branin_experiment(
with_batch=True,
with_status_quo=True,
num_batch_trial=2,
with_completed_batch=True,
)
modelbridge = ModelBridge(
search_space=exp.search_space,
experiment=exp,
model=Model,
data=exp.lookup_data(),
)
# we are able to set status_quo_data_by_trial when multiple
# status_quos present in each trial
self.assertIsNotNone(modelbridge.status_quo_data_by_trial)
# status_quo is not set
self.assertIsNone(modelbridge.status_quo)
# Status quo name can still be logged
self.assertEqual(modelbridge._status_quo_name, not_none(exp.status_quo).name)

# a unique status_quo can be identified (by trial index)
# if status_quo_features is specified
status_quo_features = ObservationFeatures(
parameters=not_none(exp.status_quo).parameters,
trial_index=0,
)
modelbridge = ModelBridge(
search_space=exp.search_space,
experiment=exp,
model=Model,
data=exp.lookup_data(),
status_quo_features=status_quo_features,
)
self.assertIsNotNone(modelbridge.status_quo)


class testClampObservationFeatures(TestCase):
def test_ClampObservationFeaturesNearBounds(self) -> None:
Expand Down
20 changes: 12 additions & 8 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def get_branin_experiment(
search_space: Optional[SearchSpace] = None,
minimize: bool = False,
named: bool = True,
num_batch_trial: int = 1,
with_completed_batch: bool = False,
with_completed_trial: bool = False,
) -> Experiment:
search_space = search_space or get_branin_search_space(
Expand All @@ -267,17 +269,19 @@ def get_branin_experiment(
),
runner=SyntheticRunner(),
is_test=True,
status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}) if with_status_quo else None,
)

if with_status_quo:
exp.status_quo = Arm(parameters={"x1": 0.0, "x2": 0.0})

if with_batch:
sobol_generator = get_sobol(search_space=exp.search_space)
sobol_run = sobol_generator.gen(n=15)
exp.new_batch_trial(optimize_for_power=with_status_quo).add_generator_run(
sobol_run
)
for _ in range(num_batch_trial):
sobol_generator = get_sobol(search_space=exp.search_space)
sobol_run = sobol_generator.gen(n=15)
trial = exp.new_batch_trial(optimize_for_power=with_status_quo)
trial.add_generator_run(sobol_run)
if with_completed_batch:
trial.mark_running(no_runner_required=True)
exp.attach_data(get_branin_data_batch(batch=trial))
trial.mark_completed()

if with_trial or with_completed_trial:
sobol_generator = get_sobol(search_space=exp.search_space)
Expand Down
8 changes: 8 additions & 0 deletions sphinx/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,14 @@ Modeling Stubs
:undoc-members:
:show-inheritance:

Preference Stubs
~~~~~~~~~~~~~~

.. automodule:: ax.utils.testing.preference_stubs
:members:
:undoc-members:
:show-inheritance:


Mocking
~~~~~~~
Expand Down

0 comments on commit 6092ac3

Please sign in to comment.