Skip to content

Commit

Permalink
Merge pull request #1043 from borglab/feature/more_discrete_methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jan 18, 2022
2 parents a74da73 + 10e1bd2 commit 296b469
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 72 deletions.
126 changes: 66 additions & 60 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,65 +25,71 @@

namespace gtsam {

// Instantiate base class
template class FactorGraph<DiscreteConditional>;

/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}

/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values);
return result;
}

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}

/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}

/* *********************************************************************** */
std::string DiscreteBayesNet::html(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
// Instantiate base class
template class FactorGraph<DiscreteConditional>;

/* ************************************************************************* */
} // namespace
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result *= (*conditional)(values);
return result;
}

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}

DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues result;
return sample(result);
}

DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}

/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}

/* *********************************************************************** */
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}

/* ************************************************************************* */
} // namespace gtsam
38 changes: 36 additions & 2 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,47 @@ namespace gtsam {
}

/**
* Solve the DiscreteBayesNet by back-substitution
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;

/** Do ancestral sampling */
/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;

/**
* @brief do ancestral sampling
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const;

/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;

///@}
/// @name Wrapper support
/// @{
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ class DiscreteBayesNet {
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand Down
43 changes: 33 additions & 10 deletions python/gtsam/tests/test_DiscreteBayesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase

# Some keys:
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)

Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)


class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
Expand All @@ -43,16 +54,6 @@ def test_constructor(self):
def test_Asia(self):
"""Test full Asia example."""

Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)

Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)

asia = DiscreteBayesNet()
asia.add(Asia, "99/1")
asia.add(Smoking, "50/50")
Expand Down Expand Up @@ -107,6 +108,28 @@ def test_Asia(self):
actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8)

def test_fragment(self):
"""Test sampling and optimizing for Asia fragment."""

# Create a reverse-topologically sorted fragment:
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")

# Create assignment with missing values:
given = DiscreteValues()
for key in [Asia, Smoking]:
given[key[0]] = 0

# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)

# Now sample from fragment:
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)


if __name__ == "__main__":
unittest.main()

0 comments on commit 296b469

Please sign in to comment.