Skip to content

Commit

Permalink
save status quo name and feature when multiple status quo present
Browse files Browse the repository at this point in the history
Summary: When multiple status_quo is present in the data, we should still be able to save status quo name and features, which should be the same for all these status quos.

Differential Revision: D57309789
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed May 14, 2024
1 parent ec9ed14 commit 81c23cb
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def __init__(
self._optimization_config: Optional[OptimizationConfig] = optimization_config
self._training_in_design: List[bool] = []
self._status_quo: Optional[Observation] = None
self._status_quo_name: Optional[str] = None
self._status_quo_features: Optional[ObservationFeatures] = None
self._arms_by_signature: Optional[Dict[str, Arm]] = None
self.transforms: MutableMapping[str, Transform] = OrderedDict()
self._model_key: Optional[str] = None
Expand Down Expand Up @@ -402,7 +404,7 @@ def _set_status_quo(
status_quo_name: Optional[str],
status_quo_features: Optional[ObservationFeatures],
) -> None:
"""Set model status quo.
"""Set model status quo by matching status_quo_name or status_quo_features.
First checks for status quo in inputs status_quo_name and
status_quo_features. If neither of these is provided, checks the
Expand All @@ -415,6 +417,7 @@ def _set_status_quo(
status_quo_features: Features for status quo.
"""
self._status_quo: Optional[Observation] = None
sq_obs = None

if (
status_quo_name is None
Expand All @@ -432,17 +435,6 @@ def _set_status_quo(
sq_obs = [
obs for obs in self._training_data if obs.arm_name == status_quo_name
]

if len(sq_obs) == 0:
logger.warning(f"Status quo {status_quo_name} not present in data")
elif len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
)
else:
self._status_quo = sq_obs[0]

elif status_quo_features is not None:
sq_obs = [
obs
Expand All @@ -451,25 +443,38 @@ def _set_status_quo(
and (obs.features.trial_index == status_quo_features.trial_index)
]

# if something is used for matching
if sq_obs is not None:
if len(sq_obs) == 0:
logger.warning(
f"Status quo features {status_quo_features} not found in data."
)
else:
# len(sq_obs) will not be > 1,
# unique features verified in _set_training_data.
self._status_quo = sq_obs[0]
logger.warning(f"Status quo {status_quo_name} not present in data")
elif len(sq_obs) >= 1:
# status_quo_features and name should be consisnt even if we have
# multiple observations with the same features.
self._status_quo_name = sq_obs[0].arm_name
self._status_quo_features = sq_obs[0].features
if len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
)
else:
# if there is a unique status_quo, set it
self._status_quo = sq_obs[0]

@property
def status_quo_data_by_trial(self) -> Optional[Dict[int, ObservationData]]:
"""A map of trial index to the status quo observation data of each trial"""
return _get_status_quo_by_trial(
observations=self._training_data,
status_quo_name=(
None if self._status_quo is None else self._status_quo.arm_name
self._status_quo_name
if self._status_quo is None
else self._status_quo.arm_name
),
status_quo_features=(
None if self._status_quo is None else self._status_quo.features
self._status_quo_features
if self._status_quo is None
else self._status_quo.features
),
)

Expand Down

0 comments on commit 81c23cb

Please sign in to comment.