From 81c23cb7e7d6904c8b3e9d63fbdb8b07ab0b0dd4 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Tue, 14 May 2024 12:06:40 -0700 Subject: [PATCH] save status quo name and feature when multiple status quo present Summary: When multiple status_quo is present in the data, we should still be able to save status quo name and features, which should be the same for all these status quos. Differential Revision: D57309789 --- ax/modelbridge/base.py | 47 +++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f152a223302..704b7f83351 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -153,6 +153,8 @@ def __init__( self._optimization_config: Optional[OptimizationConfig] = optimization_config self._training_in_design: List[bool] = [] self._status_quo: Optional[Observation] = None + self._status_quo_name: Optional[str] = None + self._status_quo_features: Optional[ObservationFeatures] = None self._arms_by_signature: Optional[Dict[str, Arm]] = None self.transforms: MutableMapping[str, Transform] = OrderedDict() self._model_key: Optional[str] = None @@ -402,7 +404,7 @@ def _set_status_quo( status_quo_name: Optional[str], status_quo_features: Optional[ObservationFeatures], ) -> None: - """Set model status quo. + """Set model status quo by matching status_quo_name or status_quo_features. First checks for status quo in inputs status_quo_name and status_quo_features. If neither of these is provided, checks the @@ -415,6 +417,7 @@ def _set_status_quo( status_quo_features: Features for status quo. """ self._status_quo: Optional[Observation] = None + sq_obs = None if ( status_quo_name is None @@ -432,17 +435,6 @@ def _set_status_quo( sq_obs = [ obs for obs in self._training_data if obs.arm_name == status_quo_name ] - - if len(sq_obs) == 0: - logger.warning(f"Status quo {status_quo_name} not present in data") - elif len(sq_obs) > 1: - logger.warning( - f"Status quo {status_quo_name} found in data with multiple " - "features. Use status_quo_features to specify which to use." - ) - else: - self._status_quo = sq_obs[0] - elif status_quo_features is not None: sq_obs = [ obs @@ -451,14 +443,23 @@ def _set_status_quo( and (obs.features.trial_index == status_quo_features.trial_index) ] + # if something is used for matching + if sq_obs is not None: if len(sq_obs) == 0: - logger.warning( - f"Status quo features {status_quo_features} not found in data." - ) - else: - # len(sq_obs) will not be > 1, - # unique features verified in _set_training_data. - self._status_quo = sq_obs[0] + logger.warning(f"Status quo {status_quo_name} not present in data") + elif len(sq_obs) >= 1: + # status_quo_features and name should be consisnt even if we have + # multiple observations with the same features. + self._status_quo_name = sq_obs[0].arm_name + self._status_quo_features = sq_obs[0].features + if len(sq_obs) > 1: + logger.warning( + f"Status quo {status_quo_name} found in data with multiple " + "features. Use status_quo_features to specify which to use." + ) + else: + # if there is a unique status_quo, set it + self._status_quo = sq_obs[0] @property def status_quo_data_by_trial(self) -> Optional[Dict[int, ObservationData]]: @@ -466,10 +467,14 @@ def status_quo_data_by_trial(self) -> Optional[Dict[int, ObservationData]]: return _get_status_quo_by_trial( observations=self._training_data, status_quo_name=( - None if self._status_quo is None else self._status_quo.arm_name + self._status_quo_name + if self._status_quo is None + else self._status_quo.arm_name ), status_quo_features=( - None if self._status_quo is None else self._status_quo.features + self._status_quo_features + if self._status_quo is None + else self._status_quo.features ), )