
Add Bayesian Additive Regression Trees (BARTs) #4183

Merged
26 commits merged on Nov 14, 2020

Conversation

aloctavodia
Member

This adds BARTs to PyMC3. At the API level a BART looks (almost) like a new distribution (more on this below). Additionally, this adds a new sampler specifically designed for BART, the PGBART sampler. This is necessary because trees are a very particular kind of discrete distribution (we can also see them as stepwise functions).

The general idea of BARTs is that, given a problem of the form y = f(X), we can approximate the unknown function f as a sum of m trees. As trees can easily overfit, BARTs put priors over the trees so that each tree is only capable of explaining a little bit of the data (for example, trees tend to be shallow), and thus we must sum many trees to get a reasonably good approximation.
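To make the sum-of-trees intuition concrete, here is a toy sketch (my own illustration, not the PyMC3 implementation, and using greedy residual fitting rather than BART's MCMC over trees): many shallow trees, each shrunk so it explains only a little of the data, add up to a good fit where a single tree cannot.

```python
import numpy as np

def fit_stump(X, r):
    # Depth-1 regression tree ("stump"): one split point, two leaf means.
    best = None
    for s in np.unique(X)[:-1]:
        left, right = r[X <= s], r[X > s]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, s, left.mean(), right.mean())
    _, s, lv, rv = best
    return lambda x, s=s, lv=lv, rv=rv: np.where(x <= s, lv, rv)

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(-2, 2, 200))
Y = np.sin(X) + rng.normal(0, 0.1, 200)

m = 50
pred = np.zeros_like(Y)
for _ in range(m):
    g = fit_stump(X, Y - pred)  # each tree fits what is still unexplained
    pred += 0.3 * g(X)          # shrink: each tree contributes only a little

mse_many = ((Y - pred) ** 2).mean()
mse_one = ((Y - fit_stump(X, Y)(X)) ** 2).mean()
```

A sum of 50 shrunken stumps fits far better than the single best stump, which is the same qualitative effect BART achieves with priors instead of greedy shrinkage.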

A 1D example.

import pymc3 as pm

# X, Y: predictor matrix and response vector, assumed already defined
with pm.Model() as model:
    σ = pm.HalfNormal('σ', 1)
    μ = pm.BART('μ', X, Y, m=50)
    y = pm.Normal('y', μ, σ, observed=Y)

[Figure: BART_simple_linear_regression_new — data with the posterior mean of μ (black line) and its HDI band]

The black line is the mean of μ and the band is the HDI of μ. As you can see, the mean is not a smooth curve, because trees are discrete. Notice that this mean is a sum of 50 trees, averaged over 2000 posterior draws (2 chains of 1000 draws each).
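For reference, the quantities shown in the plot can be recovered from the posterior draws of μ. A minimal NumPy sketch with a hypothetical draws array (in practice you would use ArviZ's az.hdi on the real trace; the equal-tailed interval below is only an approximation of the HDI):

```python
import numpy as np

# Hypothetical stand-in for the posterior draws of μ: 2000 samples × 100 points.
rng = np.random.default_rng(42)
trace_mu = rng.normal(0, 1, size=(2000, 100))

mu_mean = trace_mu.mean(axis=0)  # the black line: posterior mean at each point
# 94% equal-tailed interval as a rough stand-in for the HDI band
hdi_low, hdi_high = np.percentile(trace_mu, [3, 97], axis=0)
```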

This work is the continuation of what @jmloyola did for GSoC. The main differences are that I reduced most of the tree code to its essential parts, tried to speed things up (there is probably still room for improvement), and, mainly, focused on making BART work inside a probabilistic programming language. (I mention this because there is a family of BART methods; in general they are designed with a specific likelihood in mind, and thus they rely on conjugacy.) My goal for the BART implementation in PyMC (this will need more PRs) is for BART to become as flexible as any other distribution, so it can be combined with other distributions to create arbitrary models. At the moment its parameters m and alpha must be floats, not distributions; the main reason is that this is generally the case. There are some reports in the literature saying that putting priors on top of these parameters does not work very well computationally, but this is something I would like to explore.

Some missing features I would like to work on in future PRs: variable selection methods, better tests, documentation, and storing info that could be used for diagnostics. Also, some research to better grasp how BART behaves for real/complex datasets and ways to better select its parameters (LOO, CV, priors...).

vars : list
    List of variables for the sampler.
num_particles : int
    Number of particles for the SMC sampler. Defaults to 10.
Member
SMC -> PGBART, same below

Member Author
I tried to follow the nomenclature in the papers: one step of PGBART is the "conditional SMC" method. But I see how this can be confusing.
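For context, the "conditional" in conditional SMC means that one reference particle (here, the tree kept from the previous iteration) always survives resampling, while the rest are resampled according to their weights. A generic sketch of that resampling step (my own illustration, not the actual pgbart code):

```python
import numpy as np

def conditional_resample(particles, log_weights, ref_idx, rng):
    # Normalize log-weights into probabilities (subtract max for stability).
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    n = len(particles)
    idx = rng.choice(n, size=n, p=w)  # multinomial resampling by weight
    idx[0] = ref_idx                  # slot 0 always keeps the reference particle
    return [particles[i] for i in idx]

rng = np.random.default_rng(0)
particles = list("ABCD")
new = conditional_resample(particles, np.array([0.0, 2.0, -1.0, 0.5]), ref_idx=0, rng=rng)
```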

return particles


class Particle:
Member
Suggested change:
- class Particle:
+ class ParticleTree:

@junpenglao
Member

I will do a full review next week, but I just want to say congrats, and I'm looking forward to trying this out!

@junpenglao self-assigned this Oct 23, 2020
@codecov

codecov bot commented Oct 23, 2020

Codecov Report

Merging #4183 (5a7b552) into master (f732a01) will increase coverage by 0.02%.
The diff coverage is 90.02%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4183      +/-   ##
==========================================
+ Coverage   88.91%   88.93%   +0.02%     
==========================================
  Files          89       92       +3     
  Lines       14429    14788     +359     
==========================================
+ Hits        12829    13152     +323     
- Misses       1600     1636      +36     
Impacted Files Coverage Δ
pymc3/distributions/bart.py 80.80% <80.80%> (ø)
pymc3/distributions/tree.py 88.60% <88.60%> (ø)
pymc3/step_methods/pgbart.py 97.98% <97.98%> (ø)
pymc3/distributions/__init__.py 100.00% <100.00%> (ø)
pymc3/model.py 89.33% <100.00%> (ø)
pymc3/sampling.py 86.88% <100.00%> (+0.04%) ⬆️
pymc3/step_methods/__init__.py 100.00% <100.00%> (ø)
pymc3/step_methods/hmc/nuts.py 97.48% <100.00%> (+0.01%) ⬆️
... and 1 more

@jlevy44

jlevy44 commented Nov 11, 2020

Very impactful work, thank you for doing this!

Member

@junpenglao left a comment

Don't want to block this, so feel free to merge when tests pass.
I will test it out on master and file bugs if problems arise.

4 participants