From 10e1bd2f6167bcf667090564bf287a1ee492f6e0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 17 Jan 2022 22:59:17 -0500 Subject: [PATCH] sample variants --- gtsam/discrete/DiscreteBayesNet.cpp | 126 ++++++++++---------- gtsam/discrete/DiscreteBayesNet.h | 38 +++++- gtsam/discrete/discrete.i | 2 + python/gtsam/tests/test_DiscreteBayesNet.py | 43 +++++-- 4 files changed, 137 insertions(+), 72 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index c0dfd747c3..7294c8b296 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -25,65 +25,71 @@ namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool DiscreteBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ - double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { - // evaluate all conditionals and multiply - double result = 1.0; - for(const DiscreteConditional::shared_ptr& conditional: *this) - result *= (*conditional)(values); - return result; - } - - /* ************************************************************************* */ - DiscreteValues DiscreteBayesNet::optimize() const { - // solve each node in turn in topological sort order (parents first) - DiscreteValues result; - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->solveInPlace(&result); - return result; - } - - /* ************************************************************************* */ - DiscreteValues DiscreteBayesNet::sample() const { - // sample each node in turn in topological sort order (parents first) - DiscreteValues result; - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->sampleInPlace(&result); - return result; - } - - /* *********************************************************************** */ - std::string DiscreteBayesNet::markdown( - const KeyFormatter& keyFormatter, - const DiscreteFactor::Names& names) const { - using std::endl; - std::stringstream ss; - ss << "`DiscreteBayesNet` of size " << size() << endl << endl; - for (const DiscreteConditional::shared_ptr& conditional : *this) - ss << conditional->markdown(keyFormatter, names) << endl; - return ss.str(); - } - - /* *********************************************************************** */ - std::string DiscreteBayesNet::html( - const KeyFormatter& keyFormatter, - const DiscreteFactor::Names& names) const { - using std::endl; - std::stringstream ss; - ss << "

DiscreteBayesNet of size " << size() << "

"; - for (const DiscreteConditional::shared_ptr& conditional : *this) - ss << conditional->html(keyFormatter, names) << endl; - return ss.str(); - } +// Instantiate base class +template class FactorGraph; /* ************************************************************************* */ -} // namespace +bool DiscreteBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); +} + +/* ************************************************************************* */ +double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { + // evaluate all conditionals and multiply + double result = 1.0; + for (const DiscreteConditional::shared_ptr& conditional : *this) + result *= (*conditional)(values); + return result; +} + +/* ************************************************************************* */ +DiscreteValues DiscreteBayesNet::optimize() const { + DiscreteValues result; + return optimize(result); +} + +DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { + // solve each node in turn in topological sort order (parents first) + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->solveInPlace(&result); + return result; +} + +/* ************************************************************************* */ +DiscreteValues DiscreteBayesNet::sample() const { + DiscreteValues result; + return sample(result); +} + +DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { + // sample each node in turn in topological sort order (parents first) + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->sampleInPlace(&result); + return result; +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->markdown(keyFormatter, names) << endl; + return ss.str(); +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteBayesNet of size " << size() << "

"; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->html(keyFormatter, names) << endl; + return ss.str(); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index db20e7223a..bd5536135a 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -99,13 +99,47 @@ namespace gtsam { } /** - * Solve the DiscreteBayesNet by back-substitution + * @brief solve by back-substitution. + * + * Assumes the Bayes net is reverse topologically sorted, i.e. last + * conditional will be optimized first. If the Bayes net resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return a sampled value for all variables. */ DiscreteValues optimize() const; - /** Do ancestral sampling */ + /** + * @brief solve by back-substitution, given certain variables. + * + * Assumes the Bayes net is reverse topologically sorted *and* that the + * Bayes net does not contain any conditionals for the given values. + * + * @return given values extended with optimized value for other variables. + */ + DiscreteValues optimize(DiscreteValues given) const; + + /** + * @brief do ancestral sampling + * + * Assumes the Bayes net is reverse topologically sorted, i.e. last + * conditional will be sampled first. If the Bayes net resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return a sampled value for all variables. + */ DiscreteValues sample() const; + /** + * @brief do ancestral sampling, given certain variables. + * + * Assumes the Bayes net is reverse topologically sorted *and* that the + * Bayes net does not contain any conditionals for the given values. + * + * @return given values extended with sampled value for all other variables. + */ + DiscreteValues sample(DiscreteValues given) const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 7ce4bd9021..e4af27eb19 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -165,7 +165,9 @@ class DiscreteBayesNet { gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample() const; + gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; string markdown(const gtsam::KeyFormatter& keyFormatter, diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 36f0d153d9..6abd660cfc 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -17,6 +17,17 @@ DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) from gtsam.utils.test_case import GtsamTestCase +# Some keys: +Asia = (0, 2) +Smoking = (4, 2) +Tuberculosis = (3, 2) +LungCancer = (6, 2) + +Bronchitis = (7, 2) +Either = (5, 2) +XRay = (2, 2) +Dyspnea = (1, 2) + class TestDiscreteBayesNet(GtsamTestCase): """Tests for Discrete Bayes Nets.""" @@ -43,16 +54,6 @@ def test_constructor(self): def test_Asia(self): """Test full Asia example.""" - Asia = (0, 2) - Smoking = (4, 2) - Tuberculosis = (3, 2) - LungCancer = (6, 2) - - Bronchitis = (7, 2) - Either = (5, 2) - XRay = (2, 2) - Dyspnea = (1, 2) - asia = DiscreteBayesNet() asia.add(Asia, "99/1") asia.add(Smoking, "50/50") @@ -107,6 +108,28 @@ def test_Asia(self): actualSample = chordal2.sample() self.assertEqual(len(actualSample), 8) + def test_fragment(self): + """Test sampling and optimizing for Asia fragment.""" + + # Create a reverse-topologically sorted fragment: + fragment = DiscreteBayesNet() + fragment.add(Either, [Tuberculosis, LungCancer], "F T T T") + fragment.add(Tuberculosis, [Asia], "99/1 95/5") + fragment.add(LungCancer, [Smoking], "99/1 90/10") + + # Create assignment with missing values: + given = DiscreteValues() + for key in [Asia, Smoking]: + given[key[0]] = 0 + + # Now optimize fragment: + actual = fragment.optimize(given) + self.assertEqual(len(actual), 5) + + # Now sample from fragment: + actual = fragment.sample(given) + self.assertEqual(len(actual), 5) + if __name__ == "__main__": unittest.main()