diff --git a/ax/models/tests/test_fully_bayesian.py b/ax/models/tests/test_fully_bayesian.py deleted file mode 100644 index 16f21667421..00000000000 --- a/ax/models/tests/test_fully_bayesian.py +++ /dev/null @@ -1,1161 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import dataclasses -from contextlib import ExitStack -from itertools import count, product -from logging import Logger -from math import sqrt -from typing import Any, cast, Dict, Type -from unittest import mock - -import pyro -import torch -from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxError -from ax.models.torch.botorch import BotorchModel -from ax.models.torch.fully_bayesian import ( - FullyBayesianBotorchModel, - FullyBayesianMOOBotorchModel, - matern_kernel, - rbf_kernel, - single_task_pyro_model, -) -from ax.models.torch_base import TorchOptConfig -from ax.utils.common.constants import Keys -from ax.utils.common.logger import get_logger -from ax.utils.common.random import set_rng_seed, with_rng_seed -from ax.utils.common.testutils import TestCase -from ax.utils.testing.torch_stubs import get_torch_test_data -from botorch.acquisition.utils import get_infeasible_cost -from botorch.models import ModelListGP -from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL -from botorch.models.model import ModelList -from botorch.models.transforms.input import Warp -from botorch.optim.optimize import optimize_acqf -from botorch.posteriors.gpytorch import GPyTorchPosterior -from botorch.utils import get_objective_weights_transform -from botorch.utils.datasets import SupervisedDataset -from gpytorch.constraints import GreaterThan, Positive -from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel -from gpytorch.likelihoods import _GaussianLikelihoodBase -from pyro.infer.mcmc import MCMC, NUTS -from pyro.ops.integrator import potential_grad - - -RUN_INFERENCE_PATH = "ax.models.torch.fully_bayesian.run_inference" -NUTS_PATH = "pyro.infer.mcmc.NUTS" -MCMC_PATH = "pyro.infer.mcmc.MCMC" - -logger: Logger = get_logger(__name__) - - -def _get_dummy_mcmc_samples( - num_samples: int, - num_outputs: int, - dtype: torch.dtype, - device: torch.device, - perturb_sd: float = 1e-6, -) -> Dict[str, torch.Tensor]: - tkwargs: Dict[str, Any] = {"dtype": dtype, "device": device} - dummy_sample_list = [] - for i in range(num_outputs): - dummy_samples = { - # use real MAP values with tiny perturbations - # so that the generation code below has feasible in-sample - # points - "lengthscale": (i + 1) * torch.tensor([[1 / 3, 1 / 3, 1 / 3]], **tkwargs) - + perturb_sd * torch.randn(num_samples, 1, 3, **tkwargs), - "outputscale": (i + 1) * torch.tensor(2.3436, **tkwargs) - + perturb_sd * torch.randn(num_samples, **tkwargs), - "mean": (i + 1) * torch.tensor([3.5000], **tkwargs) - + perturb_sd * torch.randn(num_samples, **tkwargs), - } - dummy_samples["kernel_tausq"] = (i + 1) * torch.tensor(0.5, **tkwargs) - dummy_samples["_kernel_inv_length_sq"] = ( - # pyre-fixme[6]: For 2nd param expected `Tensor` but got `float`. - 1.0 - # pyre-fixme[58]: `/` is not supported for operand types `float` and - # `Tensor`. - / dummy_samples["lengthscale"].sqrt() - ) - dummy_sample_list.append(dummy_samples) - # pyre-fixme[7]: Expected `Dict[str, Tensor]` but got `List[typing.Any]`. 
- return dummy_sample_list - - -def dummy_func(X: torch.Tensor) -> torch.Tensor: - return X - - -class BaseFullyBayesianBotorchModelTestCases: - class FullyBayesianBotorchModelTest(TestCase): - model_cls: Type[BotorchModel] = FullyBayesianBotorchModel - - def test_FullyBayesianBotorchModel( - self, dtype: torch.dtype = torch.float, cuda: bool = False - ) -> None: - # test deprecation warning - with self.assertWarnsRegex( - DeprecationWarning, "Passing `use_saas` is no longer supported" - ): - self.model_cls(use_saas=True) - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( - dtype=dtype, cuda=cuda, constant_noise=True - ) - Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( - dtype=dtype, cuda=cuda, constant_noise=True - ) - Yvars_inferred_noise = [ - torch.full_like(Yvars1[0], float("nan")), - torch.full_like(Yvars2[0], float("nan")), - ] - # make input different for each output - Xs2_diff = [Xs2[0] + 0.1] - Xs = Xs1 + Xs2_diff - Ys = Ys1 + Ys2 - - options = product([True, False], [True, False], ["matern", "rbf"]) - for inferred_noise, use_input_warping, gp_kernel in options: - Yvars = Yvars_inferred_noise if inferred_noise else Yvars1 + Yvars2 - model = self.model_cls( - use_input_warping=use_input_warping, - thinning=1, - num_samples=4, - disable_progbar=True, - max_tree_depth=1, - gp_kernel=gp_kernel, - verbose=True, - ) - if use_input_warping: - self.assertTrue(model.use_input_warping) - # Test ModelListGP - # make training data different for each output - tkwargs: Dict[str, Any] = {"dtype": dtype, "device": Xs1[0].device} - dummy_samples_list = _get_dummy_mcmc_samples( - num_samples=4, num_outputs=2, **tkwargs - ) - for dummy_samples in dummy_samples_list: - if use_input_warping: - # pyre-fixme[16]: `str` has no attribute `__setitem__`. - dummy_samples["c0"] = ( - torch.rand(4, 1, Xs1[0].shape[-1], **tkwargs) * 0.5 + 0.1 - ) - dummy_samples["c1"] = ( - torch.rand(4, 1, Xs1[0].shape[-1], **tkwargs) * 0.5 + 0.1 - ) - if inferred_noise: - dummy_samples["noise"] = torch.rand(4, 1, **tkwargs).clamp_min( - MIN_INFERRED_NOISE_LEVEL - ) - - with mock.patch( - RUN_INFERENCE_PATH, - side_effect=dummy_samples_list, - ) as _mock_fit_model: - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[f"y{i}"], - ) - for X, Y, Yvar, i in zip(Xs, Ys, Yvars, count()) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - - # Check that there are no unexpected constraints on the - # hyperparameters - for _, m in enumerate(cast(ModelList, model.model).models): - if inferred_noise: - noise_covar = m.likelihood.noise_covar - self.assertEqual( - noise_covar.raw_noise_constraint.__class__, GreaterThan - ) - self.assertTrue( - torch.allclose( - noise_covar.raw_noise_constraint.lower_bound, - torch.tensor(1e-4, dtype=dtype), - ) - ) - else: - self.assertFalse( - hasattr( - m.likelihood.noise_covar, "raw_noise_constraint" - ) - ) - - self.assertEqual(m.covar_module.__class__, ScaleKernel) - self.assertEqual( - m.covar_module.raw_outputscale_constraint.__class__, - Positive, - ) - self.assertEqual( - m.covar_module.base_kernel.__class__, - RBFKernel if gp_kernel == "rbf" else MaternKernel, - ) - self.assertEqual(m.covar_module.base_kernel.ard_num_dims, 3) - ls_constraint = ( - m.covar_module.base_kernel.raw_lengthscale_constraint - ) - self.assertEqual( - ls_constraint.__class__, - Positive, - ) - - # attribute `assertEqual`. 
- self.assertEqual(_mock_fit_model.call_count, 2) - for i, call in enumerate(_mock_fit_model.call_args_list): - _, ckwargs = call - X = Xs[i] - Y = Ys[i] - Yvar = Yvars[i] - self.assertIs(ckwargs["pyro_model"], single_task_pyro_model) - - self.assertTrue(torch.equal(ckwargs["X"], X)) - self.assertTrue(torch.equal(ckwargs["Y"], Y)) - if inferred_noise: - self.assertTrue(torch.isnan(ckwargs["Yvar"]).all()) - else: - self.assertTrue(torch.equal(ckwargs["Yvar"], Yvar)) - self.assertEqual(ckwargs["num_samples"], 4) - self.assertEqual(ckwargs["warmup_steps"], 512) - self.assertEqual(ckwargs["max_tree_depth"], 1) - self.assertTrue(ckwargs["disable_progbar"]) - self.assertFalse(ckwargs["jit_compile"]) - self.assertEqual( - ckwargs["use_input_warping"], use_input_warping - ) - self.assertEqual(ckwargs["gp_kernel"], gp_kernel) - self.assertTrue(ckwargs["verbose"]) - - # Check attributes - self.assertTrue(torch.equal(model.Xs[i], Xs[i])) - self.assertEqual(model.dtype, Xs[i].dtype) - self.assertEqual(model.device, Xs[i].device) - self.assertIsInstance(model.model, ModelListGP) - - # Check fitting - # Note each model in the model list is a batched model, where - # the batch dim corresponds to the MCMC samples - model_list = cast(ModelList, model.model).models - # Put model in `eval` mode to transform the train inputs. - m = model_list[i].eval() - # check mcmc samples - # pyre-fixme[6]: For 1st param expected `str` but got `int`. - dummy_samples = dummy_samples_list[i] - expected_train_inputs = Xs[i].expand(4, *Xs[i].shape) - if use_input_warping: - # train inputs should be warped inputs - expected_train_inputs = m.input_transform( - expected_train_inputs - ) - self.assertTrue( - torch.equal( - m.train_inputs[0], - expected_train_inputs, - ) - ) - self.assertTrue( - torch.equal( - m.train_targets, - Ys[i].view(1, -1).expand(4, Ys[i].numel()), - ) - ) - expected_noise = ( - # pyre-fixme[6]: For 1st param expected `Union[None, - # List[typing.Any], int, slice, Tensor, - # typing.Tuple[typing.Any, ...]]` but got `str`. - dummy_samples["noise"].view(m.likelihood.noise.shape) - if inferred_noise - else Yvars[i].view(1, -1).expand(4, Yvars[i].numel()) - ) - self.assertTrue( - torch.allclose( - m.likelihood.noise.detach(), - expected_noise, - ) - ) - self.assertIsInstance(m.likelihood, _GaussianLikelihoodBase) - self.assertTrue( - torch.allclose( - m.covar_module.base_kernel.lengthscale.detach(), - # pyre-fixme[6]: For 1st param expected `Union[None, - # List[typing.Any], int, slice, Tensor, - # typing.Tuple[typing.Any, ...]]` but got `str`. - dummy_samples["lengthscale"].view( - m.covar_module.base_kernel.lengthscale.shape - ), - rtol=1e-4, - atol=1e-6, - ) - ) - self.assertTrue( - torch.allclose( - m.covar_module.outputscale.detach(), - # pyre-fixme[6]: For 1st param expected `Union[None, - # List[typing.Any], int, slice, Tensor, - # typing.Tuple[typing.Any, ...]]` but got `str`. - dummy_samples["outputscale"].view( - m.covar_module.outputscale.shape - ), - ) - ) - self.assertTrue( - torch.allclose( - m.mean_module.constant.detach(), - # pyre-fixme[6]: For 1st param expected `Union[None, - # List[typing.Any], int, slice, Tensor, - # typing.Tuple[typing.Any, ...]]` but got `str`. - dummy_samples["mean"].view( - m.mean_module.constant.shape - ), - ) - ) - if use_input_warping: - self.assertTrue(hasattr(m, "input_transform")) - self.assertIsInstance(m.input_transform, Warp) - self.assertTrue( - torch.equal( - m.input_transform.concentration0, - # pyre-fixme[6]: For 1st param expected `str` - # but got `int`. 
- # pyre-fixme[6]: For 1st param expected - # `Union[None, List[typing.Any], int, slice, - # Tensor, typing.Tuple[typing.Any, ...]]` but got - # `str`. - dummy_samples_list[i]["c0"], - ) - ) - self.assertTrue( - torch.equal( - m.input_transform.concentration1, - # pyre-fixme[6]: For 1st param expected `str` - # but got `int`. - # pyre-fixme[6]: For 1st param expected - # `Union[None, List[typing.Any], int, slice, - # Tensor, typing.Tuple[typing.Any, ...]]` but got - # `str`. - dummy_samples_list[i]["c1"], - ) - ) - else: - self.assertFalse(hasattr(m, "input_transform")) - # test that multi-task is not implemented - ( - Xs_mt, - Ys_mt, - Yvars_mt, - bounds_mt, - tfs_mt, - fns_mt, - mns_mt, - ) = get_torch_test_data( - dtype=dtype, cuda=cuda, constant_noise=True, task_features=[2] - ) - with mock.patch( - RUN_INFERENCE_PATH, - side_effect=dummy_samples_list, - ) as _mock_fit_model, self.assertRaises(NotImplementedError): - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns_mt, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip(Xs_mt, Ys_mt, Yvars_mt, mns_mt) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns_mt, - bounds=bounds_mt, - task_features=tfs_mt, - ), - ) - with mock.patch( - RUN_INFERENCE_PATH, - side_effect=dummy_samples_list, - ) as _mock_fit_model, self.assertRaises(NotImplementedError): - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - fidelity_features=[0], - ), - ) - # fit model with same inputs (otherwise X_observed will be None) - model = self.model_cls( - use_input_warping=use_input_warping, - thinning=1, - num_samples=4, - disable_progbar=True, - max_tree_depth=1, - gp_kernel=gp_kernel, - ) - Yvars = Yvars1 + Yvars2 - dummy_samples_list = _get_dummy_mcmc_samples( - num_samples=4, num_outputs=2, **tkwargs - ) - with mock.patch( - RUN_INFERENCE_PATH, - side_effect=dummy_samples_list, - ) as _mock_fit_model: - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars, mns * 2 - ) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - - # Check the hyperparameters and shapes - models = cast(ModelList, model.model).models - self.assertEqual(len(models), 2) - m1, m2 = models - # Mean - self.assertEqual(m1.mean_module.constant.shape, (4,)) - self.assertFalse( - torch.isclose( - m1.mean_module.constant, m2.mean_module.constant - ).any() - ) - # Outputscales - self.assertEqual(m1.covar_module.outputscale.shape, (4,)) - self.assertFalse( - torch.isclose( - m1.covar_module.outputscale, m2.covar_module.outputscale - ).any() - ) - # Lengthscales - self.assertEqual( - m1.covar_module.base_kernel.lengthscale.shape, (4, 1, 3) - ) - self.assertFalse( - torch.isclose( - m1.covar_module.base_kernel.lengthscale, - m2.covar_module.base_kernel.lengthscale, - ).any() - ) - - # Check infeasible cost can be computed on the model - device = torch.device("cuda") if cuda else torch.device("cpu") - objective_weights = torch.tensor([1.0, 0.0], dtype=dtype, device=device) - objective_transform = get_objective_weights_transform(objective_weights) - infeasible_cost = ( - get_infeasible_cost( - X=Xs1[0], - model=model.model, - 
objective=objective_transform, - ) - .detach() - .clone() - ) - posterior = cast(GPyTorchPosterior, model.model.posterior(Xs1[0])) - expected_infeasible_cost = -1 * torch.min( - # pyre-fixme[20]: Argument `1` expected. - objective_transform( - posterior.mean - 6 * posterior.variance.sqrt() - ).min(), - torch.tensor(0.0, dtype=dtype, device=device), - ) - self.assertTrue( - torch.abs(infeasible_cost - expected_infeasible_cost) < 1e-5 - ) - - # Check prediction - X = torch.tensor([[6.0, 7.0, 8.0]], **tkwargs) - f_mean, f_cov = model.predict(X) - self.assertTrue(f_mean.shape == torch.Size([1, 2])) - self.assertTrue(f_cov.shape == torch.Size([1, 2, 2])) - - # Check generation - objective_weights = torch.tensor( - ( - [1.0, 0.0] - if self.model_cls is FullyBayesianBotorchModel - else [1.0, 1.0] - ), - **tkwargs, - ) - outcome_constraints = ( - torch.tensor([[0.0, 1.0]], **tkwargs), - torch.tensor([[5.0]], **tkwargs), - ) - linear_constraints = ( - torch.tensor([[0.0, 1.0, 1.0]]), - torch.tensor([[100.0]]), - ) - fixed_features = None - pending_observations = [ - torch.tensor([[1.0, 3.0, 4.0]], **tkwargs), - torch.tensor([[2.0, 6.0, 8.0]], **tkwargs), - ] - n = 3 - - X_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs) - acqfv_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs) - model_gen_options = { - Keys.OPTIMIZER_KWARGS: {"options": {"maxiter": 1}}, - Keys.ACQF_KWARGS: {"mc_samples": 3}, - } - search_space_digest = SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - ) - torch_opt_config = TorchOptConfig( - objective_weights=objective_weights, - objective_thresholds=( - torch.zeros(2, **tkwargs) - if self.model_cls is FullyBayesianMOOBotorchModel - else None - ), - outcome_constraints=outcome_constraints, - linear_constraints=linear_constraints, - fixed_features=fixed_features, - pending_observations=pending_observations, - # pyre-fixme[6]: For 7th param expected `Dict[str, Union[None, - # Dict[str, typing.Any], OptimizationConfig, AcquisitionFunction, - # float, int, str]]` but got `Dict[Keys, Dict[str, int]]`. 
- model_gen_options=model_gen_options, - rounding_func=dummy_func, - is_moo=self.model_cls is FullyBayesianMOOBotorchModel, - ) - # test sequential optimize with constraints - with mock.patch( - "ax.models.torch.botorch_defaults.optimize_acqf", - return_value=(X_dummy, acqfv_dummy), - ) as _: - # Xgen, wgen, gen_metadata, cand_metadata = model.gen( - gen_results = model.gen( - n=n, - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) - # note: gen() always returns CPU tensors - self.assertTrue(torch.equal(gen_results.points, X_dummy.cpu())) - self.assertTrue( - torch.equal(gen_results.weights, torch.ones(n, dtype=dtype)) - ) - - # actually test optimization for 1 step without constraints - with mock.patch( - "ax.models.torch.botorch_defaults.optimize_acqf", - wraps=optimize_acqf, - return_value=(X_dummy, acqfv_dummy), - ) as _: - gen_results = model.gen( - n=n, - search_space_digest=search_space_digest, - torch_opt_config=dataclasses.replace( - torch_opt_config, linear_constraints=None - ), - ) - # note: gen() always returns CPU tensors - self.assertTrue(torch.equal(gen_results.points, X_dummy.cpu())) - self.assertTrue( - torch.equal(gen_results.weights, torch.ones(n, dtype=dtype)) - ) - - # Check best point selection - if self.model_cls is FullyBayesianMOOBotorchModel: - with self.assertRaisesRegex(NotImplementedError, "Best observed"): - model.best_point( - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) - else: - self.assertIsNotNone( - model.best_point( - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) - ) - self.assertIsNone( - model.best_point( - search_space_digest=search_space_digest, - torch_opt_config=dataclasses.replace( - torch_opt_config, fixed_features={0: 100.0} - ), - ) - ) - - # Test cross-validation - mean, variance = model.cross_validate( - datasets=[ - SupervisedDataset( - X=X, Y=Y, Yvar=Yvar, feature_names=fns, outcome_names=[mn] - ) - for X, Y, Yvar, mn in zip(Xs1 + Xs2, Ys1 + Ys2, Yvars, mns * 2) - ], - X_test=torch.tensor( - [[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], dtype=dtype, device=device - ), - ) - self.assertTrue(mean.shape == torch.Size([2, 2])) - self.assertTrue(variance.shape == torch.Size([2, 2, 2])) - - # Test cross-validation with refit_on_cv - model.refit_on_cv = True - with mock.patch( - RUN_INFERENCE_PATH, - side_effect=dummy_samples_list, - ) as _mock_fit_model: - mean, variance = model.cross_validate( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars, mns * 2 - ) - ], - X_test=torch.tensor( - [[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], - dtype=dtype, - device=device, - ), - ) - self.assertTrue(mean.shape == torch.Size([2, 2])) - self.assertTrue(variance.shape == torch.Size([2, 2, 2])) - - # Test feature_importances - importances = model.feature_importances() - self.assertEqual(importances.shape, torch.Size([2, 1, 3])) - - # test unfit model CV and feature_importances - unfit_model = self.model_cls() - with self.assertRaises(RuntimeError): - unfit_model.cross_validate( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - X_test=Xs1[0], - ) - with self.assertRaises(RuntimeError): - unfit_model.feature_importances() - - def test_saasbo_sample(self) -> None: - for use_input_warping, gp_kernel in product( - [False, 
True], ["rbf", "matern"] - ): - with with_rng_seed(0): - X = torch.randn(3, 2) - Y = torch.randn(3, 1) - Yvar = torch.randn(3, 1) - kernel = NUTS(single_task_pyro_model, max_tree_depth=1) - mcmc = MCMC(kernel, warmup_steps=0, num_samples=1) - mcmc.run( - X, - Y, - Yvar, - use_input_warping=use_input_warping, - gp_kernel=gp_kernel, - ) - samples = mcmc.get_samples() - self.assertTrue("kernel_tausq" in samples) - self.assertTrue("_kernel_inv_length_sq" in samples) - self.assertTrue("lengthscale" not in samples) - if use_input_warping: - self.assertIn("c0", samples) - self.assertIn("c1", samples) - else: - self.assertNotIn("c0", samples) - self.assertNotIn("c1", samples) - - def test_gp_kernels(self) -> None: - set_rng_seed(0) - X = torch.randn(3, 2) - Y = torch.randn(3, 1) - Yvar = torch.randn(3, 1) - kernel = NUTS(single_task_pyro_model, max_tree_depth=1) - with self.assertRaises(ValueError): - mcmc = MCMC(kernel, warmup_steps=0, num_samples=1) - mcmc.run( - X, - Y, - Yvar, - gp_kernel="some_kernel_we_dont_support", - ) - - def test_FullyBayesianBotorchModel_cuda(self) -> None: - if torch.cuda.is_available(): - self.test_FullyBayesianBotorchModel(cuda=True) - - def test_FullyBayesianBotorchModel_double(self) -> None: - self.test_FullyBayesianBotorchModel(dtype=torch.double) - - def test_FullyBayesianBotorchModel_double_cuda(self) -> None: - if torch.cuda.is_available(): - self.test_FullyBayesianBotorchModel(dtype=torch.double, cuda=True) - - def test_FullyBayesianBotorchModelConstraints(self) -> None: - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( - dtype=torch.float, cuda=False, constant_noise=True - ) - Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( - dtype=torch.float, cuda=False, constant_noise=True - ) - # make infeasible - Xs2[0] = -1 * Xs2[0] - objective_weights = torch.tensor( - [-1.0, 1.0], dtype=torch.float, device=torch.device("cpu") - ) - n = 3 - model = self.model_cls( - num_samples=4, - thinning=1, - disable_progbar=True, - max_tree_depth=1, - ) - dummy_samples = _get_dummy_mcmc_samples( - num_samples=4, num_outputs=2, dtype=torch.float, device=Xs1[0].device - ) - search_space_digest = SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ) - with mock.patch( - RUN_INFERENCE_PATH, side_effect=dummy_samples - ) as _mock_fit_model: - model.fit( - datasets=[ - SupervisedDataset( - X=X, Y=Y, Yvar=Yvar, feature_names=fns, outcome_names=[mn] - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - search_space_digest=search_space_digest, - ) - self.assertEqual(_mock_fit_model.call_count, 2) - - # because there are no feasible points: - with self.assertRaises(ValueError): - model.gen( - n, - search_space_digest=search_space_digest, - torch_opt_config=TorchOptConfig( - objective_weights=objective_weights - ), - ) - - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def test_FullyBayesianBotorchModelPyro(self, dtype=torch.double, cuda=False): - Xs1, Ys1, raw_Yvars1, bounds, tfs, fns, mns = get_torch_test_data( - dtype=dtype, cuda=cuda, constant_noise=True - ) - Xs2, Ys2, raw_Yvars2, _, _, _, _ = get_torch_test_data( - dtype=dtype, cuda=cuda, constant_noise=True - ) - options = [(False, True, "rbf"), (True, False, "matern")] - for inferred_noise, use_input_warping, gp_kernel in options: - model = self.model_cls( - num_samples=4, - warmup_steps=0, - thinning=1, - use_input_warping=use_input_warping, - disable_progbar=True, - max_tree_depth=1, - gp_kernel=gp_kernel, - verbose=True, - ) - if inferred_noise: - Yvars1 = [torch.full_like(raw_Yvars1[0], float("nan"))] - Yvars2 = [torch.full_like(raw_Yvars2[0], float("nan"))] - else: - Yvars1 = raw_Yvars1 - Yvars2 = raw_Yvars2 - - dummy_samples = _get_dummy_mcmc_samples( - num_samples=4, - num_outputs=2, - dtype=dtype, - device=Xs1[0].device, - ) - with ExitStack() as es: - _mock_fit_model = es.enter_context( - mock.patch(RUN_INFERENCE_PATH, side_effect=dummy_samples) - ) - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - # check run_inference arguments - self.assertEqual(_mock_fit_model.call_count, 2) - _, ckwargs = _mock_fit_model.call_args - self.assertIs(ckwargs["pyro_model"], single_task_pyro_model) - self.assertFalse(ckwargs["jit_compile"]) - self.assertTrue(torch.equal(ckwargs["X"], Xs1[0])) - self.assertTrue(torch.equal(ckwargs["Y"], Ys1[0])) - if inferred_noise: - self.assertTrue(torch.isnan(ckwargs["Yvar"]).all()) - else: - self.assertTrue(torch.equal(ckwargs["Yvar"], Yvars1[0])) - self.assertEqual(ckwargs["num_samples"], 4) - self.assertEqual(ckwargs["warmup_steps"], 0) - self.assertEqual(ckwargs["max_tree_depth"], 1) - self.assertTrue(ckwargs["disable_progbar"]) - self.assertFalse(ckwargs["jit_compile"]) - self.assertEqual(ckwargs["use_input_warping"], use_input_warping) - self.assertEqual(ckwargs["gp_kernel"], gp_kernel) - self.assertTrue(ckwargs["verbose"]) - - with ExitStack() as es: - _mock_mcmc = es.enter_context(mock.patch(MCMC_PATH)) - _mock_mcmc.return_value.get_samples.side_effect = dummy_samples - _mock_nuts = es.enter_context(mock.patch(NUTS_PATH)) - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - # check MCMC.__init__ arguments - self.assertEqual(_mock_mcmc.call_count, 2) - _, ckwargs = _mock_mcmc.call_args - self.assertEqual(ckwargs["num_samples"], 4) - self.assertEqual(ckwargs["warmup_steps"], 0) - self.assertTrue(ckwargs["disable_progbar"]) - # check NUTS.__init__ arguments - _mock_nuts.assert_called_with( - single_task_pyro_model, - jit_compile=False, - full_mass=True, - ignore_jit_warnings=True, - max_tree_depth=1, - ) - # now actually run pyro - if not use_input_warping: - # input warping is quite slow, so we omit it for - # testing purposes - model.fit( - datasets=[ - SupervisedDataset( - X=X, - Y=Y, - Yvar=Yvar, - feature_names=fns, - outcome_names=[mn], - ) - for X, Y, Yvar, mn in zip( - Xs1 + Xs2, Ys1 + Ys2, Yvars1 + Yvars2, mns * 2 - ) - ], - 
search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - - for m, X, Y, Yvar in zip( - cast(ModelList, model.model).models, - Xs1 + Xs2, - Ys1 + Ys2, - Yvars1 + Yvars2, - ): - self.assertTrue( - torch.equal( - m.train_inputs[0], - X.expand(4, *X.shape), - ) - ) - self.assertTrue( - torch.equal( - m.train_targets, - Y.view(1, -1).expand(4, Y.numel()), - ) - ) - # check shapes of sampled parameters - if not inferred_noise: - self.assertTrue( - torch.allclose( - m.likelihood.noise.detach(), - Yvar.view(1, -1).expand(4, Yvar.numel()), - ) - ) - else: - self.assertEqual( - m.likelihood.noise.shape, torch.Size([4, 1]) - ) - - self.assertEqual( - m.covar_module.base_kernel.lengthscale.shape, - torch.Size([4, 1, X.shape[-1]]), - ) - self.assertEqual( - m.covar_module.outputscale.shape, torch.Size([4]) - ) - self.assertEqual( - m.mean_module.constant.shape, - torch.Size([4]), - ) - if use_input_warping: - self.assertTrue(hasattr(m, "input_transform")) - self.assertIsInstance(m.input_transform, Warp) - self.assertEqual( - m.input_transform.concentration0.shape, - torch.Size([4, 1, 3]), - ) - self.assertEqual( - m.input_transform.concentration1.shape, - torch.Size([4, 1, 3]), - ) - else: - self.assertFalse(hasattr(m, "input_transform")) - - def test_FullyBayesianBotorchModelPyro_float(self) -> None: - self.test_FullyBayesianBotorchModelPyro(dtype=torch.float, cuda=False) - - def test_FullyBayesianBotorchModelPyro_cuda_double(self) -> None: - if torch.cuda.is_available(): - self.test_FullyBayesianBotorchModelPyro(dtype=torch.double, cuda=True) - - def test_FullyBayesianBotorchModelPyro_cuda_float(self) -> None: - if torch.cuda.is_available(): - self.test_FullyBayesianBotorchModelPyro(dtype=torch.float, cuda=True) - - -class FullyBayesianBotorchModelTest( - BaseFullyBayesianBotorchModelTestCases.FullyBayesianBotorchModelTest -): - def test_FullyBayesianBotorchModelOneOutcome(self) -> None: - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( - dtype=torch.float, cuda=False, constant_noise=True - ) - for use_input_warping, gp_kernel in product([True, False], ["rbf", "matern"]): - model = self.model_cls( - use_input_warping=use_input_warping, - num_samples=4, - thinning=1, - disable_progbar=True, - max_tree_depth=1, - gp_kernel=gp_kernel, - ) - dummy_samples = _get_dummy_mcmc_samples( - num_samples=4, - num_outputs=1, - dtype=torch.float, - device=Xs1[0].device, - ) - with mock.patch( - RUN_INFERENCE_PATH, side_effect=dummy_samples - ) as _mock_fit_model: - model.fit( - datasets=[ - SupervisedDataset( - X=X, Y=Y, Yvar=Yvar, feature_names=fns, outcome_names=[mn] - ) - for X, Y, Yvar, mn in zip(Xs1, Ys1, Yvars1, mns) - ], - search_space_digest=SearchSpaceDigest( - feature_names=fns, - bounds=bounds, - task_features=tfs, - ), - ) - _mock_fit_model.assert_called_once() - X = torch.rand(2, 3, dtype=torch.float) - f_mean, f_cov = model.predict(X) - self.assertTrue(f_mean.shape == torch.Size([2, 1])) - self.assertTrue(f_cov.shape == torch.Size([2, 1, 1])) - model_list = cast(ModelList, model.model).models - self.assertTrue(len(model_list) == 1) - if use_input_warping: - self.assertTrue(hasattr(model_list[0], "input_transform")) - self.assertIsInstance(model_list[0].input_transform, Warp) - else: - self.assertFalse(hasattr(model_list[0], "input_transform")) - - -class FullyBayesianMOOBotorchModelTest( - BaseFullyBayesianBotorchModelTestCases.FullyBayesianBotorchModelTest -): - model_cls = FullyBayesianMOOBotorchModel - - -class 
TestKernels(TestCase): - def test_matern_kernel(self) -> None: - a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1) - b = torch.tensor([0, 2], dtype=torch.float).view(2, 1) - lengthscale = 2 - # test matern 1/2 - # pyre-fixme[6]: For 4th param expected `Tensor` but got `int`. - res = matern_kernel(a, b, nu=0.5, lengthscale=lengthscale) - actual = ( - torch.tensor([[4, 2], [2, 0], [8, 6]], dtype=torch.float) - .div_(-lengthscale) - .exp() - ) - self.assertLess(torch.linalg.norm(res - actual), 1e-3) - # matern test 3/2 - # pyre-fixme[6]: For 4th param expected `Tensor` but got `int`. - res = matern_kernel(a, b, nu=1.5, lengthscale=lengthscale) - dist = torch.tensor([[4, 2], [2, 0], [8, 6]], dtype=torch.float).mul_( - sqrt(3) / lengthscale - ) - actual = (dist + 1).mul(torch.exp(-dist)) - self.assertLess(torch.linalg.norm(res - actual), 1e-3) - # matern test 5/2 - # pyre-fixme[6]: For 4th param expected `Tensor` but got `int`. - res = matern_kernel(a, b, nu=2.5, lengthscale=lengthscale) - dist = torch.tensor([[4, 2], [2, 0], [8, 6]], dtype=torch.float).mul_( - sqrt(5) / lengthscale - ) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - actual = (dist**2 / 3 + dist + 1).mul(torch.exp(-dist)) - self.assertLess(torch.linalg.norm(res - actual), 1e-3) - - # test k(x,x) with no gradients - # pyre-fixme[6]: For 4th param expected `Tensor` but got `float`. - res = matern_kernel(b, b, nu=0.5, lengthscale=2.0) - actual = ( - torch.tensor([[0, 2], [2, 0]], dtype=torch.float).div_(-lengthscale).exp() - ) - self.assertLess(torch.linalg.norm(res - actual), 1e-3) - - # test unsupported nu - with self.assertRaises(AxError): - # pyre-fixme[6]: For 4th param expected `Tensor` but got `float`. - matern_kernel(b, b, nu=0.0, lengthscale=2.0) - - def test_rbf_kernel(self) -> None: - a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1) - b = torch.tensor([0, 2], dtype=torch.float).view(2, 1) - lengthscale = 2 - # test rbf - # pyre-fixme[6]: For 3rd param expected `Tensor` but got `int`. - res = rbf_kernel(a, b, lengthscale=lengthscale) - actual = ( - torch.tensor([[4, 2], [2, 0], [8, 6]], dtype=torch.float) - .pow_(2.0) - .mul_(-0.5 / (lengthscale**2)) - .exp() - ) - self.assertLess(torch.linalg.norm(res - actual), 1e-3) - - -class TestPyroCatchNumericalErrors(TestCase): - # This test is to verify that the pyro exception handlers are properly registered, - # which should happen upon importing from botorch.models.fully_bayesian.py - # (which in turn happens within ax.models.torch.fully_bayesian.py). - - def test_pyro_catch_error(self) -> None: - def potential_fn(z: Dict[str, torch.Tensor]) -> torch.Tensor: - # pyre-fixme[16]: Module `distributions` has no attribute - # `MultivariateNormal`. 
- mvn = pyro.distributions.MultivariateNormal( - loc=torch.zeros(2), - covariance_matrix=z["K"], - ) - return mvn.log_prob(torch.zeros(2)) - - # Test base case where everything is fine - z = {"K": torch.eye(2)} - grads, val = potential_grad(potential_fn, z) - self.assertTrue(torch.allclose(grads["K"], -0.5 * torch.eye(2))) - norm_mvn = torch.distributions.Normal(0, 1) - self.assertTrue(torch.allclose(val, 2 * norm_mvn.log_prob(torch.zeros(1)))) - - # Default behavior should catch the ValueError when trying to instantiate - # the MVN and return NaN instead - z = {"K": torch.ones(2, 2)} - _, val = potential_grad(potential_fn, z) - self.assertTrue(torch.isnan(val)) - - # Default behavior should catch the LinAlgError when performing a - # Cholesky decomposition and return NaN instead - def potential_fn_chol(z: Dict[str, torch.Tensor]) -> torch.Tensor: - return torch.linalg.cholesky(z["K"]) - - _, val = potential_grad(potential_fn_chol, z) - self.assertTrue(torch.isnan(val)) - - # Default behavior should not catch other errors - def potential_fn_rterr_foo(z: Dict[str, torch.Tensor]) -> torch.Tensor: - raise RuntimeError("foo") - - with self.assertRaisesRegex(RuntimeError, "foo"): - potential_grad(potential_fn_rterr_foo, z) diff --git a/ax/models/torch/fully_bayesian.py b/ax/models/torch/fully_bayesian.py deleted file mode 100644 index ec0a65345c1..00000000000 --- a/ax/models/torch/fully_bayesian.py +++ /dev/null @@ -1,654 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -""" -Models and utilities for fully Bayesian inference. - -TODO: move some of this into botorch. - -References - -.. [Eriksson2021saasbo] - D. Eriksson, M. Jankowiak. High-Dimensional Bayesian Optimization - with Sparse Axis-Aligned Subspaces. Proceedings of the Thirty- - Seventh Conference on Uncertainty in Artificial Intelligence, 2021. - -.. [Eriksson2021nas] - D. Eriksson, P. Chuang, S. Daulton, et al. Latency-Aware Neural - Architecture Search with Multi-Objective Bayesian Optimization. - ICML AutoML Workshop, 2021. 
- -""" - -import math -import sys -import time -import types -import warnings - -from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np - -import pyro -import torch -from ax.exceptions.core import AxError -from ax.models.torch.botorch import ( - BotorchModel, - TAcqfConstructor, - TBestPointRecommender, - TModelConstructor, - TModelPredictor, - TOptimizer, -) -from ax.models.torch.botorch_defaults import ( - get_qLogNEI, - MIN_OBSERVED_NOISE_LEVEL, - recommend_best_observed_point, - scipy_optimizer, -) -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel -from ax.models.torch.botorch_moo_defaults import ( - get_qLogNEHVI, - pareto_frontier_evaluator, -) -from ax.models.torch.frontier_utils import TFrontierEvaluator -from ax.models.torch.fully_bayesian_model_utils import ( - _get_single_task_gpytorch_model, - load_mcmc_samples_to_model, - pyro_sample_input_warping, - pyro_sample_mean, - pyro_sample_noise, - pyro_sample_outputscale, - pyro_sample_saas_lengthscales, -) -from ax.utils.common.docutils import copy_doc -from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast -from botorch.acquisition import AcquisitionFunction -from botorch.models.gpytorch import GPyTorchModel -from botorch.models.model import Model -from botorch.models.model_list_gp_regression import ModelListGP -from botorch.posteriors.gpytorch import GPyTorchPosterior -from botorch.utils.safe_math import logmeanexp -from gpytorch.kernels.kernel import dist -from torch import Tensor - -logger: Logger = get_logger(__name__) - - -SAAS_DEPRECATION_MSG = ( - "Passing `use_saas` is no longer supported and has no effect. " - "SAAS priors are used by default. " - "This will become an error in the future." -) - - -def predict_from_model_mcmc(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: - r"""Predicts outcomes given a model and input tensor. - - This method integrates over the hyperparameter posterior. - - Args: - model: A batched botorch Model where the batch dimension corresponds - to sampled hyperparameters. - X: A `n x d` tensor of input parameters. - - Returns: - Tensor: The predicted posterior mean as an `n x o`-dim tensor. - Tensor: The predicted posterior covariance as a `n x o x o`-dim tensor. 
- """ - with torch.no_grad(): - # compute the batch (independent posterior over the inputs) - posterior = checked_cast(GPyTorchPosterior, model.posterior(X.unsqueeze(-3))) - # the mean and variance both have shape: n x num_samples x m (after squeezing) - mean = posterior.mean.cpu().detach() - # TODO: Allow Posterior to (optionally) return the full covariance matrix - variance = posterior.variance.cpu().detach().clamp_min(0) - # marginalize over samples - t1 = variance.sum(dim=0) / variance.shape[0] - t2 = mean.pow(2).sum(dim=0) / variance.shape[0] - t3 = -(mean.sum(dim=0) / variance.shape[0]).pow(2) - variance = t1 + t2 + t3 - mean = mean.mean(dim=0) - cov = torch.diag_embed(variance) - return mean, cov - - -def compute_dists(X: Tensor, Z: Tensor, lengthscale: Tensor) -> Tensor: - """Compute kernel distances.""" - mean = X.mean(dim=0) - x1 = (X - mean).div(lengthscale) - x2 = (Z - mean).div(lengthscale) - return dist(x1=x1, x2=x2, x1_eq_x2=torch.equal(x1, x2)) - - -def matern_kernel(X: Tensor, Z: Tensor, lengthscale: Tensor, nu: float = 2.5) -> Tensor: - """Scaled Matern kernel.""" - dist = compute_dists(X=X, Z=Z, lengthscale=lengthscale) - exp_component = torch.exp(-math.sqrt(nu * 2) * dist) - - if nu == 0.5: - constant_component = 1 - elif nu == 1.5: - constant_component = (math.sqrt(3) * dist).add(1) - elif nu == 2.5: - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - constant_component = (math.sqrt(5) * dist).add(1).add(5.0 / 3.0 * (dist**2)) - else: - raise AxError(f"Unsupported value of nu: {nu}") - return constant_component * exp_component - - -def rbf_kernel(X: Tensor, Z: Tensor, lengthscale: Tensor) -> Tensor: - """Scaled RBF kernel.""" - dist = compute_dists(X=X, Z=Z, lengthscale=lengthscale) - # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`. - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - return torch.exp(-0.5 * (dist**2)) - - -def single_task_pyro_model( - X: Tensor, - Y: Tensor, - Yvar: Tensor, - use_input_warping: bool = False, - eps: float = 1e-7, - gp_kernel: str = "matern", - task_feature: Optional[int] = None, - rank: Optional[int] = None, -) -> None: - r"""Instantiates a single task pyro model for running fully bayesian inference. - - Args: - X: A `n x d` tensor of input parameters. - Y: A `n x 1` tensor of output. - Yvar: A `n x 1` tensor of observed noise. - use_input_warping: A boolean indicating whether to use input warping - task_feature: Column index of task feature in X. - gp_kernel: kernel name. Currently only two kernels are supported: "matern" for - Matern Kernel and "rbf" for RBFKernel. - rank: num of latent task features to learn for task covariance. - """ - Y = Y.view(-1) - Yvar = Yvar.view(-1) - tkwargs = {"dtype": X.dtype, "device": X.device} - dim = X.shape[-1] - # TODO: test alternative outputscale priors - outputscale = pyro_sample_outputscale(concentration=2.0, rate=0.15, **tkwargs) - mean = pyro_sample_mean(**tkwargs) - if torch.isnan(Yvar).all(): - # infer noise level - noise = MIN_OBSERVED_NOISE_LEVEL + pyro_sample_noise(**tkwargs) - else: - noise = Yvar.clamp_min(MIN_OBSERVED_NOISE_LEVEL) - # pyre-fixme[6]: For 2nd param expected `float` but got `Union[device, dtype]`. 
- lengthscale = pyro_sample_saas_lengthscales(dim=dim, **tkwargs) - - # transform inputs through Kumaraswamy CDF - if use_input_warping: - c0, c1 = pyro_sample_input_warping(dim=dim, **tkwargs) - # unnormalize X from [0, 1] to [eps, 1-eps] - X = (X * (1 - 2 * eps) + eps).clamp(eps, 1 - eps) - X_tf = 1 - torch.pow((1 - torch.pow(X, c1)), c0) - else: - X_tf = X - # compute kernel - if gp_kernel == "matern": - K = matern_kernel(X=X_tf, Z=X_tf, lengthscale=lengthscale) - elif gp_kernel == "rbf": - K = rbf_kernel(X=X_tf, Z=X_tf, lengthscale=lengthscale) - else: - raise ValueError(f"Expected kernel to be 'rbf' or 'matern', got {gp_kernel}") - - # add noise - K = outputscale * K + noise * torch.eye(X.shape[0], dtype=X.dtype, device=X.device) - - pyro.sample( - "Y", - # pyre-fixme[16]: Module `distributions` has no attribute `MultivariateNormal`. - pyro.distributions.MultivariateNormal( - loc=mean.view(-1).expand(X.shape[0]), - covariance_matrix=K, - ), - obs=Y, - ) - - -def _get_model_mcmc_samples( - Xs: List[Tensor], - Ys: List[Tensor], - Yvars: List[Tensor], - task_features: List[int], - fidelity_features: List[int], - metric_names: List[str], - state_dict: Optional[Dict[str, Tensor]] = None, - refit_model: bool = True, - use_input_warping: bool = False, - use_loocv_pseudo_likelihood: bool = False, - num_samples: int = 256, - warmup_steps: int = 512, - thinning: int = 16, - max_tree_depth: int = 6, - disable_progbar: bool = False, - gp_kernel: str = "matern", - verbose: bool = False, - jit_compile: bool = False, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - pyro_model: Callable = single_task_pyro_model, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - get_gpytorch_model: Callable = _get_single_task_gpytorch_model, - rank: Optional[int] = 1, - **kwargs: Any, - # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use `typing.Dict` - # to avoid runtime subscripting errors. -) -> Tuple[ModelListGP, List[Dict]]: - r"""Instantiates a batched GPyTorchModel (ModelListGP) based on the given data and - fits the model based on MCMC in pyro. - - Args: - pyro_model: callable to instantiate a pyro model for running MCMC. - get_gpytorch_model: callable to instantiate a coupled GPyTorchModel to load the - returned MCMC samples. - """ - model = get_gpytorch_model( - Xs=Xs, - Ys=Ys, - Yvars=Yvars, - task_features=task_features, - fidelity_features=fidelity_features, - state_dict=state_dict, - num_samples=num_samples, - thinning=thinning, - use_input_warping=use_input_warping, - gp_kernel=gp_kernel, - **kwargs, - ) - if state_dict is not None: - # Expected `OrderedDict[typing.Any, typing.Any]` for 1st - # param but got `Dict[str, Tensor]`. 
- model.load_state_dict(state_dict) - - mcmc_samples_list = [] - if len(task_features) > 0: - task_feature = task_features[0] - else: - task_feature = None - if state_dict is None or refit_model: - for X, Y, Yvar in zip(Xs, Ys, Yvars): - mcmc_samples = run_inference( - pyro_model=pyro_model, - X=X, - Y=Y, - Yvar=Yvar, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, - use_input_warping=use_input_warping, - max_tree_depth=max_tree_depth, - disable_progbar=disable_progbar, - gp_kernel=gp_kernel, - verbose=verbose, - task_feature=task_feature, - rank=rank, - jit_compile=jit_compile, - ) - mcmc_samples_list.append(mcmc_samples) - return model, mcmc_samples_list - - -def get_and_fit_model_mcmc( - Xs: List[Tensor], - Ys: List[Tensor], - Yvars: List[Tensor], - task_features: List[int], - fidelity_features: List[int], - metric_names: List[str], - state_dict: Optional[Dict[str, Tensor]] = None, - refit_model: bool = True, - use_input_warping: bool = False, - use_loocv_pseudo_likelihood: bool = False, - num_samples: int = 256, - warmup_steps: int = 512, - thinning: int = 16, - max_tree_depth: int = 6, - disable_progbar: bool = False, - gp_kernel: str = "matern", - verbose: bool = False, - jit_compile: bool = False, - **kwargs: Any, -) -> GPyTorchModel: - r"""Instantiates a batched GPyTorchModel (ModelListGP) based on the given data and - fits the model based on MCMC in pyro. The batch dimension corresponds to sampled - hyperparameters from MCMC. - """ - model, mcmc_samples_list = _get_model_mcmc_samples( - Xs=Xs, - Ys=Ys, - Yvars=Yvars, - task_features=task_features, - fidelity_features=fidelity_features, - metric_names=metric_names, - state_dict=state_dict, - refit_model=refit_model, - use_input_warping=use_input_warping, - use_loocv_pseudo_likelihood=use_loocv_pseudo_likelihood, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, - max_tree_depth=max_tree_depth, - disable_progbar=disable_progbar, - gp_kernel=gp_kernel, - verbose=verbose, - jit_compile=jit_compile, - pyro_model=single_task_pyro_model, - get_gpytorch_model=_get_single_task_gpytorch_model, - ) - for i, mcmc_samples in enumerate(mcmc_samples_list): - load_mcmc_samples_to_model(model=model.models[i], mcmc_samples=mcmc_samples) - return model - - -def run_inference( - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
- pyro_model: Callable, - X: Tensor, - Y: Tensor, - Yvar: Tensor, - num_samples: int = 256, - warmup_steps: int = 512, - thinning: int = 16, - use_input_warping: bool = False, - max_tree_depth: int = 6, - disable_progbar: bool = False, - gp_kernel: str = "matern", - verbose: bool = False, - task_feature: Optional[int] = None, - rank: Optional[int] = None, - jit_compile: bool = False, -) -> Dict[str, Tensor]: - start = time.time() - try: - from pyro.infer.mcmc import MCMC, NUTS - from pyro.infer.mcmc.util import print_summary - except ImportError: - raise RuntimeError("Cannot call run_inference without pyro installed!") - kernel = NUTS( - pyro_model, - jit_compile=jit_compile, - full_mass=True, - ignore_jit_warnings=True, - max_tree_depth=max_tree_depth, - ) - mcmc = MCMC( - kernel, - warmup_steps=warmup_steps, - num_samples=num_samples, - disable_progbar=disable_progbar, - ) - mcmc.run( - X, - Y, - Yvar, - use_input_warping=use_input_warping, - gp_kernel=gp_kernel, - task_feature=task_feature, - rank=rank, - ) - - # compute the true lengthscales and get rid of the temporary variables - samples = mcmc.get_samples() - inv_length_sq = ( - samples["kernel_tausq"].unsqueeze(-1) * samples["_kernel_inv_length_sq"] - ) - samples["lengthscale"] = (1.0 / inv_length_sq).sqrt() # pyre-ignore [16] - del samples["kernel_tausq"], samples["_kernel_inv_length_sq"] - # this prints the summary - if verbose: - orig_std_out = sys.stdout.write - sys.stdout.write = logger.info # pyre-fixme[8] - print_summary(samples, prob=0.9, group_by_chain=False) - sys.stdout.write = orig_std_out - logger.info(f"MCMC elapsed time: {time.time() - start}") - # thin - for k, v in samples.items(): - samples[k] = v[::thinning] # apply thinning - return samples - - -def get_fully_bayesian_acqf( - model: Model, - objective_weights: Tensor, - outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, - **kwargs: Any, -) -> AcquisitionFunction: - """NOTE: An `acqf_constructor` with which the underlying acquisition function - is constructed is optionally extracted from `kwargs` and defaults to NEI. - - We did not add `acqf_constructor` directly to the argument list of - `get_fully_bayesian_acqf` so that it satisfies the `TAcqfConstructor` Protocol - that is shared by all other legacy Ax acquisition function constructors. - """ - kwargs["marginalize_dim"] = -3 - acqf_constructor: TAcqfConstructor = kwargs.pop("acqf_constructor", get_qLogNEI) - acqf = acqf_constructor( - model=model, - objective_weights=objective_weights, - outcome_constraints=outcome_constraints, - X_observed=X_observed, - X_pending=X_pending, - **kwargs, - ) - base_forward = acqf.forward - # enabling manual override, default to True for LogEI - log = kwargs.pop("log", acqf._log) - sample_reduction = torch.mean if not log else logmeanexp - - # pyre-fixme[53]: Captured variable `base_forward` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def forward(self, X): - # unsqueeze dim for GP hyperparameter samples - return sample_reduction(base_forward(X.unsqueeze(-3)), dim=-1) - - acqf.forward = types.MethodType(forward, acqf) # pyre-ignore[8] - return acqf - - -def get_fully_bayesian_acqf_nehvi( - model: Model, - objective_weights: Tensor, - outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None, - X_observed: Optional[Tensor] = None, - X_pending: Optional[Tensor] = None, - **kwargs: Any, -) -> AcquisitionFunction: - return get_fully_bayesian_acqf( - model=model, - objective_weights=objective_weights, - outcome_constraints=outcome_constraints, - X_observed=X_observed, - X_pending=X_pending, - acqf_constructor=get_qLogNEHVI, - **kwargs, - ) - - -class FullyBayesianBotorchModelMixin: - _model: Optional[Model] = None - - def feature_importances(self) -> np.ndarray: - if self._model is None: - raise RuntimeError( - "Cannot calculate feature_importances without a fitted model" - ) - elif isinstance(self._model, ModelListGP): - models = self._model.models - else: - models = [self._model] - lengthscales = [] - for m in models: - ls = m.covar_module.base_kernel.lengthscale - lengthscales.append(ls) - lengthscales = torch.stack(lengthscales, dim=0) - # take mean over MCMC samples - lengthscales = torch.quantile(lengthscales, 0.5, dim=1) - # pyre-ignore [16] - # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`. - return (1 / lengthscales).detach().cpu().numpy() - - -class FullyBayesianBotorchModel(FullyBayesianBotorchModelMixin, BotorchModel): - r"""Fully Bayesian Model that uses NUTS to sample from hyperparameter posterior. - - This includes support for using sparse axis-aligned subspace priors (SAAS). See - [Eriksson2021saasbo]_ for details. - """ - - def __init__( - self, - model_constructor: TModelConstructor = get_and_fit_model_mcmc, - model_predictor: TModelPredictor = predict_from_model_mcmc, - acqf_constructor: TAcqfConstructor = get_fully_bayesian_acqf, - # pyre-fixme[9]: acqf_optimizer declared/used type mismatch - acqf_optimizer: TOptimizer = scipy_optimizer, - best_point_recommender: TBestPointRecommender = recommend_best_observed_point, - refit_on_cv: bool = False, - warm_start_refitting: bool = True, - use_input_warping: bool = False, - # use_saas is deprecated. TODO: remove - use_saas: Optional[bool] = None, - num_samples: int = 256, - warmup_steps: int = 512, - thinning: int = 16, - max_tree_depth: int = 6, - disable_progbar: bool = False, - gp_kernel: str = "matern", - verbose: bool = False, - jit_compile: bool = False, - **kwargs: Any, - ) -> None: - """Initialize Fully Bayesian Botorch Model. - - Args: - model_constructor: A callable that instantiates and fits a model on data, - with signature as described below. - model_predictor: A callable that predicts using the fitted model, with - signature as described below. - acqf_constructor: A callable that creates an acquisition function from a - fitted model, with signature as described below. - acqf_optimizer: A callable that optimizes the acquisition function, with - signature as described below. - best_point_recommender: A callable that recommends the best point, with - signature as described below. - refit_on_cv: If True, refit the model for each fold when performing - cross-validation. - warm_start_refitting: If True, start model refitting from previous - model parameters in order to speed up the fitting process. 
- use_input_warping: A boolean indicating whether to use input warping - use_saas: [deprecated] A boolean indicating whether to use the SAAS model - num_samples: The number of MCMC samples. Note that with thinning, - num_samples/thinning samples are retained. - warmup_steps: The number of burn-in steps for NUTS. - thinning: The amount of thinning. Every nth sample is retained. - max_tree_depth: The max_tree_depth for NUTS. - disable_progbar: A boolean indicating whether to print the progress - bar and diagnostics during MCMC. - gp_kernel: The type of ARD base kernel. "matern" corresponds to a Matern-5/2 - kernel and "rbf" corresponds to an RBF kernel. - verbose: A boolean indicating whether to print summary stats from MCMC. - """ - # use_saas is deprecated. TODO: remove - if use_saas is not None: - warnings.warn(SAAS_DEPRECATION_MSG, DeprecationWarning) - BotorchModel.__init__( - self, - model_constructor=model_constructor, - model_predictor=model_predictor, - acqf_constructor=acqf_constructor, - acqf_optimizer=acqf_optimizer, - best_point_recommender=best_point_recommender, - refit_on_cv=refit_on_cv, - warm_start_refitting=warm_start_refitting, - use_input_warping=use_input_warping, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, - max_tree_depth=max_tree_depth, - disable_progbar=disable_progbar, - gp_kernel=gp_kernel, - verbose=verbose, - jit_compile=jit_compile, - ) - - -class FullyBayesianMOOBotorchModel( - FullyBayesianBotorchModelMixin, MultiObjectiveBotorchModel -): - r"""Fully Bayesian Model that uses qNEHVI. - - This includes support for using qNEHVI + SAASBO as in [Eriksson2021nas]_. - """ - - @copy_doc(FullyBayesianBotorchModel.__init__) - def __init__( - self, - model_constructor: TModelConstructor = get_and_fit_model_mcmc, - model_predictor: TModelPredictor = predict_from_model_mcmc, - acqf_constructor: TAcqfConstructor = get_fully_bayesian_acqf_nehvi, - # pyre-fixme[9]: acqf_optimizer has type `Callable[[AcquisitionFunction, - # Tensor, int, Optional[Dict[int, float]], Optional[Callable[[Tensor], - # Tensor]], Any], Tensor]`; used as `Callable[[AcquisitionFunction, Tensor, - # int, Optional[Dict[int, float]], Optional[Callable[[Tensor], Tensor]], - # **(Any)], Tensor]`. - acqf_optimizer: TOptimizer = scipy_optimizer, - # TODO: Remove best_point_recommender for botorch_moo. Used in modelbridge._gen. - best_point_recommender: TBestPointRecommender = recommend_best_observed_point, - frontier_evaluator: TFrontierEvaluator = pareto_frontier_evaluator, - refit_on_cv: bool = False, - warm_start_refitting: bool = False, - use_input_warping: bool = False, - num_samples: int = 256, - warmup_steps: int = 512, - thinning: int = 16, - max_tree_depth: int = 6, - # use_saas is deprecated. TODO: remove - use_saas: Optional[bool] = None, - disable_progbar: bool = False, - gp_kernel: str = "matern", - verbose: bool = False, - jit_compile: bool = False, - **kwargs: Any, - ) -> None: - # use_saas is deprecated. 
TODO: remove - if use_saas is not None: - warnings.warn(SAAS_DEPRECATION_MSG, DeprecationWarning) - MultiObjectiveBotorchModel.__init__( - self, - model_constructor=model_constructor, - model_predictor=model_predictor, - acqf_constructor=acqf_constructor, - acqf_optimizer=acqf_optimizer, - best_point_recommender=best_point_recommender, - frontier_evaluator=frontier_evaluator, - refit_on_cv=refit_on_cv, - warm_start_refitting=warm_start_refitting, - use_input_warping=use_input_warping, - num_samples=num_samples, - warmup_steps=warmup_steps, - thinning=thinning, - max_tree_depth=max_tree_depth, - disable_progbar=disable_progbar, - gp_kernel=gp_kernel, - verbose=verbose, - jit_compile=jit_compile, - ) diff --git a/ax/models/torch/fully_bayesian_model_utils.py b/ax/models/torch/fully_bayesian_model_utils.py deleted file mode 100644 index 3338c6fb40f..00000000000 --- a/ax/models/torch/fully_bayesian_model_utils.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Any, Dict, List, Optional, Tuple - -import pyro -import torch -from ax.models.torch.botorch_defaults import _get_model -from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL -from botorch.models.gpytorch import GPyTorchModel -from botorch.models.model_list_gp_regression import ModelListGP -from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel -from torch import Tensor - - -def _get_rbf_kernel(num_samples: int, dim: int) -> ScaleKernel: - return ScaleKernel( - base_kernel=RBFKernel(ard_num_dims=dim, batch_shape=torch.Size([num_samples])), - batch_shape=torch.Size([num_samples]), - ) - - -def _get_matern_kernel(num_samples: int, dim: int) -> ScaleKernel: - return ScaleKernel( - base_kernel=MaternKernel( - ard_num_dims=dim, batch_shape=torch.Size([num_samples]) - ), - batch_shape=torch.Size([num_samples]), - ) - - -def _get_single_task_gpytorch_model( - Xs: List[Tensor], - Ys: List[Tensor], - Yvars: List[Tensor], - task_features: List[int], - fidelity_features: List[int], - state_dict: Optional[Dict[str, Tensor]] = None, - num_samples: int = 512, - thinning: int = 16, - use_input_warping: bool = False, - gp_kernel: str = "matern", - **kwargs: Any, -) -> ModelListGP: - r"""Instantiates a batched GPyTorchModel(ModelListGP) based on the given data. - The model fitting is based on MCMC and is run separately using pyro. The MCMC - samples will be loaded into the model instantiated here afterwards. - - Returns: - A ModelListGP. - """ - if len(task_features) > 0: - raise NotImplementedError("Currently do not support MT-GP models with MCMC!") - if len(fidelity_features) > 0: - raise NotImplementedError( - "Fidelity MF-GP models are not currently supported with MCMC!" 
- ) - - num_mcmc_samples = num_samples // thinning - covar_modules = [ - ( - _get_rbf_kernel(num_samples=num_mcmc_samples, dim=Xs[0].shape[-1]) - if gp_kernel == "rbf" - else _get_matern_kernel(num_samples=num_mcmc_samples, dim=Xs[0].shape[-1]) - ) - for _ in range(len(Xs)) - ] - - models = [ - _get_model( - X=X.unsqueeze(0).expand(num_mcmc_samples, X.shape[0], -1), - Y=Y.unsqueeze(0).expand(num_mcmc_samples, Y.shape[0], -1), - Yvar=Yvar.unsqueeze(0).expand(num_mcmc_samples, Yvar.shape[0], -1), - fidelity_features=fidelity_features, - use_input_warping=use_input_warping, - covar_module=covar_module, - **kwargs, - ) - for X, Y, Yvar, covar_module in zip(Xs, Ys, Yvars, covar_modules) - ] - model = ModelListGP(*models) - model.to(Xs[0]) - return model - - -def pyro_sample_outputscale( - concentration: float = 2.0, - rate: float = 0.15, - **tkwargs: Any, -) -> Tensor: - - return pyro.sample( - "outputscale", - # pyre-fixme[16]: Module `distributions` has no attribute `Gamma` - pyro.distributions.Gamma( - torch.tensor(concentration, **tkwargs), - torch.tensor(rate, **tkwargs), - ), - ) - - -def pyro_sample_mean(**tkwargs: Any) -> Tensor: - - return pyro.sample( - "mean", - # pyre-fixme[16]: Module `distributions` has no attribute `Normal`. - pyro.distributions.Normal( - torch.tensor(0.0, **tkwargs), - torch.tensor(1.0, **tkwargs), - ), - ) - - -def pyro_sample_noise(**tkwargs: Any) -> Tensor: - - # this prefers small noise but has heavy tails - return pyro.sample( - "noise", - # pyre-fixme[16]: Module `distributions` has no attribute `Gamma`. - pyro.distributions.Gamma( - torch.tensor(0.9, **tkwargs), - torch.tensor(10.0, **tkwargs), - ), - ) - - -def pyro_sample_saas_lengthscales( - dim: int, - alpha: float = 0.1, - **tkwargs: Any, -) -> Tensor: - - tausq = pyro.sample( - "kernel_tausq", - # pyre-fixme[16]: Module `distributions` has no attribute `HalfCauchy`. - pyro.distributions.HalfCauchy(torch.tensor(alpha, **tkwargs)), - ) - inv_length_sq = pyro.sample( - "_kernel_inv_length_sq", - # pyre-fixme[16]: Module `distributions` has no attribute `HalfCauchy`. - pyro.distributions.HalfCauchy(torch.ones(dim, **tkwargs)), - ) - inv_length_sq = pyro.deterministic("kernel_inv_length_sq", tausq * inv_length_sq) - lengthscale = pyro.deterministic( - "lengthscale", - (1.0 / inv_length_sq).sqrt(), # pyre-ignore [16] - ) - return lengthscale - - -def pyro_sample_input_warping( - dim: int, - **tkwargs: Any, -) -> Tuple[Tensor, Tensor]: - - c0 = pyro.sample( - "c0", - # pyre-fixme[16]: Module `distributions` has no attribute `LogNormal`. - pyro.distributions.LogNormal( - torch.tensor([0.0] * dim, **tkwargs), - torch.tensor([0.75**0.5] * dim, **tkwargs), - ), - ) - c1 = pyro.sample( - "c1", - # pyre-fixme[16]: Module `distributions` has no attribute `LogNormal`. - pyro.distributions.LogNormal( - torch.tensor([0.0] * dim, **tkwargs), - torch.tensor([0.75**0.5] * dim, **tkwargs), - ), - ) - return c0, c1 - - -# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use `typing.Dict` -# to avoid runtime subscripting errors. 
-def load_mcmc_samples_to_model(model: GPyTorchModel, mcmc_samples: Dict) -> None: - """Load MCMC samples into GPyTorchModel.""" - if "noise" in mcmc_samples: - model.likelihood.noise_covar.noise = ( - mcmc_samples["noise"] - .detach() - .clone() - .view(model.likelihood.noise_covar.noise.shape) - .clamp_min(MIN_INFERRED_NOISE_LEVEL) - ) - model.covar_module.base_kernel.lengthscale = ( - mcmc_samples["lengthscale"] - .detach() - .clone() - .view(model.covar_module.base_kernel.lengthscale.shape) - ) - model.covar_module.outputscale = ( - mcmc_samples["outputscale"] - .detach() - .clone() - .view(model.covar_module.outputscale.shape) - ) - model.mean_module.constant.data = ( - mcmc_samples["mean"].detach().clone().view(model.mean_module.constant.shape) - ) - if "c0" in mcmc_samples: - model.input_transform._set_concentration( - i=0, - value=mcmc_samples["c0"] - .detach() - .clone() - .view(model.input_transform.concentration0.shape), - ) - model.input_transform._set_concentration( - i=1, - value=mcmc_samples["c1"] - .detach() - .clone() - .view(model.input_transform.concentration1.shape), - ) diff --git a/ax/utils/testing/mock.py b/ax/utils/testing/mock.py index f3c30ec7f36..ed43bd41f32 100644 --- a/ax/utils/testing/mock.py +++ b/ax/utils/testing/mock.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Dict, Generator, Optional from unittest import mock -from ax.models.torch.fully_bayesian import run_inference from botorch.fit import fit_fully_bayesian_model_nuts from botorch.generation.gen import minimize_with_timeout from botorch.optim.initializers import ( @@ -54,9 +53,6 @@ def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Optional[Tensor]: return gen_one_shot_kg_initial_conditions(*args, **kwargs) - def minimal_run_inference(*args: Any, **kwargs: Any) -> Dict[str, Tensor]: - return run_inference(*args, **_get_minimal_mcmc_kwargs(**kwargs)) - def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None: fit_fully_bayesian_model_nuts(*args, **_get_minimal_mcmc_kwargs(**kwargs)) @@ -89,13 +85,6 @@ def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None: ) ) - mock_mcmc_legacy = es.enter_context( - mock.patch( - "ax.models.torch.fully_bayesian.run_inference", - wraps=minimal_run_inference, - ) - ) - mock_mcmc_mbm = es.enter_context( mock.patch( "ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts", @@ -112,7 +101,6 @@ def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None: mock_fit, mock_gen_ics, mock_gen_os_ics, - mock_mcmc_legacy, mock_mcmc_mbm, ] ):
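Note on the removal (context, not part of the patch): the deleted `FullyBayesianBotorchModel` ran its own NUTS loop via `run_inference` and loaded the draws into a batched `ModelListGP`. The retained `mock.py` hunk above still patches `ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts`, which suggests the Modular BoTorch stack is the surviving SAAS path. Below is a minimal sketch of that replacement flow using BoTorch's `SaasFullyBayesianSingleTaskGP`; the training tensors are illustrative and not taken from this diff.

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Illustrative toy data; double precision is recommended for GP fitting.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.randn(20, 1, dtype=torch.double)

gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
# Same NUTS defaults the deleted model used: 512 warmup steps, 256 samples,
# thinning of 16, i.e. 16 retained hyperparameter draws.
fit_fully_bayesian_model_nuts(
    gp,
    warmup_steps=512,
    num_samples=256,
    thinning=16,
    disable_progbar=True,
)
# The posterior is a mixture over hyperparameter draws; mixture_mean
# marginalizes over them, much as the deleted predict_from_model_mcmc did.
posterior = gp.posterior(torch.rand(5, 3, dtype=torch.double))
mean = posterior.mixture_mean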