Bring back BART to V4 and make it more general (#4914)
* forward-porting from unreleased v3 plus generalization

* aesarize

* improve docstrings

* small fix docstring and variable names

* fix format variable importance

* fix broadcasting issue and other minor fixes

* add test and fix pylint

* fix float32

* sample splitting variables non-uniformly

* remove xfail

* add back xfail on windows

* add back xfail on windows and for float32

* fix test

* clean rnd

* add size argument and check for NoDistribution

* stop updating split_prior after tuning

* clean code and small speed-up

* clean code and small speed-up

* revert xfail

* add tests

* fix number of chains

* revert test

* clean code, refactor and small speed-up

* test random

* test random

* add missing data test
aloctavodia authored Sep 3, 2021
1 parent 61fa834 commit ff45994
Showing 8 changed files with 563 additions and 379 deletions.
336 changes: 96 additions & 240 deletions pymc3/distributions/bart.py
@@ -14,271 +14,127 @@

import numpy as np

from pandas import DataFrame, Series
from aesara.tensor.random.op import RandomVariable, default_shape_from_params

from pymc3.distributions.distribution import NoDistribution
from pymc3.distributions.tree import LeafNode, SplitNode, Tree

__all__ = ["BART"]


class BaseBART(NoDistribution):
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):

self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)

super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs)

if self.X.ndim != 2:
raise ValueError("The design matrix X must have two dimensions")

if self.Y.ndim != 1:
raise ValueError("The response matrix Y must have one dimension")
if self.X.shape[0] != self.Y.shape[0]:
raise ValueError(
"The design matrix X and the response matrix Y must have the same number of elements"
)
if not isinstance(m, int):
raise ValueError("The number of trees m type must be int")
if m < 1:
raise ValueError("The number of trees m must be greater than zero")

if alpha <= 0 or 1 <= alpha:
raise ValueError(
"The value for the alpha parameter for the tree structure "
"must be in the interval (0, 1)"
)

self.num_observations = X.shape[0]
self.num_variates = X.shape[1]
self.available_predictors = list(range(self.num_variates))
self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
self.m = m
self.alpha = alpha
self.trees = self.init_list_of_trees()
self.all_trees = []
self.mean = fast_mean()
self.prior_prob_leaf_node = compute_prior_probability(alpha)

def preprocess_XY(self, X, Y):
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
X = X.to_numpy()
missing_data = np.any(np.isnan(X))
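# perturb X with a small amount of Gaussian noise (one hundredth of each column's standard deviation)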
X = np.random.normal(X, np.std(X, 0) / 100)
return X, Y, missing_data

def init_list_of_trees(self):
initial_value_leaf_nodes = self.Y.mean() / self.m
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
list_of_trees = []
for i in range(self.m):
new_tree = Tree.init_tree(
tree_id=i,
leaf_node_value=initial_value_leaf_nodes,
idx_data_points=initial_idx_data_points_leaf_nodes,
)
list_of_trees.append(new_tree)
# Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J.
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
# The sum_trees_output will contain the sum of the predicted output for all trees.
# When R_j is needed we subtract the current predicted output for tree T_j.
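# get_residuals_loo below implements exactly this: R_j = Y - (sum_trees_output - T_j's prediction).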
self.sum_trees_output = np.full_like(self.Y, self.Y.mean())

return list_of_trees

def __iter__(self):
return iter(self.trees)

def __repr_latex(self):
raise NotImplementedError

def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
x_j = self.X[idx_data_points_split_node, idx_split_variable]
if self.missing_data:
x_j = x_j[~np.isnan(x_j)]
values = np.unique(x_j)
# The last value is never available as it would leave the right subtree empty.
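# e.g. unique values [1.5, 2.0, 7.3] yield candidate split values [1.5, 2.0];
# splitting at 7.3 would send every data point to the left child.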
return values[:-1]

def grow_tree(self, tree, index_leaf_node):
current_node = tree.get_node(index_leaf_node)

index_selected_predictor = self.ssv.rvs()
selected_predictor = self.available_predictors[index_selected_predictor]
available_splitting_rules = self.get_available_splitting_rules(
current_node.idx_data_points, selected_predictor
)
# This can be unsuccessful when there are no available splitting rules
if available_splitting_rules.size == 0:
return False, None

index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=selected_splitting_rule,
)

left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
new_split_node, current_node.idx_data_points
)

left_node_value = self.draw_leaf_value(left_node_idx_data_points)
right_node_value = self.draw_leaf_value(right_node_idx_data_points)

new_left_node = LeafNode(
index=current_node.get_idx_left_child(),
value=left_node_value,
idx_data_points=left_node_idx_data_points,
)
new_right_node = LeafNode(
index=current_node.get_idx_right_child(),
value=right_node_value,
idx_data_points=right_node_idx_data_points,
)
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

return True, index_selected_predictor

def get_new_idx_data_points(self, current_split_node, idx_data_points):
idx_split_variable = current_split_node.idx_split_variable
split_value = current_split_node.split_value

left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
left_node_idx_data_points = idx_data_points[left_idx]
right_node_idx_data_points = idx_data_points[~left_idx]

return left_node_idx_data_points, right_node_idx_data_points

def get_residuals(self):
"""Compute the residuals."""
R_j = self.Y - self.sum_trees_output
return R_j

def get_residuals_loo(self, tree):
"""Compute the residuals without leaving the passed tree out."""
R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
return R_j

def draw_leaf_value(self, idx_data_points):
"""Draw the residual mean."""
R_j = self.get_residuals()[idx_data_points]
draw = self.mean(R_j)
return draw

def predict(self, X_new):
"""Compute out of sample predictions evaluated at X_new"""
trees = self.all_trees
num_observations = X_new.shape[0]
pred = np.zeros((len(trees), num_observations))
np.random.randint(len(trees))
for draw, trees_to_sum in enumerate(trees):
new_Y = np.zeros(num_observations)
for tree in trees_to_sum:
new_Y += [tree.predict_out_of_sample(x) for x in X_new]
pred[draw] = new_Y
return pred


def compute_prior_probability(alpha):
"""
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 19 in [Rockova2018].

Parameters
----------
alpha : float

Returns
-------
list with probabilities for leaf nodes

References
----------
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
"""
prior_leaf_prob = [0]
depth = 1
while prior_leaf_prob[-1] < 1:
prior_leaf_prob.append(1 - alpha ** depth)
depth += 1
return prior_leaf_prob
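
# e.g. alpha=0.25 gives [0, 0.75, 0.9375, 0.984375, ...]: a node at depth d is a
# leaf with probability 1 - alpha**d, so deep nodes are almost surely leaves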


def fast_mean():
"""If available use Numba to speed up the computation of the mean."""
try:
from numba import jit
except ImportError:
return np.mean

@jit
def mean(a):
count = a.shape[0]
suma = 0
for i in range(count):
suma += a[i]
return suma / count

return mean


class BARTRV(RandomVariable):
"""
Base class for BART
"""

name = "BART"
ndim_supp = 1
ndims_params = [2, 1, 0, 0, 0, 1]
dtype = "floatX"
_print_name = ("BART", "\\operatorname{BART}")
all_trees = None

def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)

@classmethod
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
size = kwargs.pop("size", None)
X_new = kwargs.pop("X_new", None)
all_trees = cls.all_trees
if all_trees:

if size is None:
size = ()
elif isinstance(size, int):
size = [size]

flatten_size = 1
for s in size:
flatten_size *= s

idx = rng.randint(len(all_trees), size=flatten_size)
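
# each requested draw sums the predictions of every tree in one randomly chosen ensemble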

if X_new is None:
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += tree.predict_output()
else:
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())

def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value)."""
return int(np.random.random() * upper_value)


class SampleSplittingVariable:
def __init__(self, prior, num_variates):
self.prior = prior
self.num_variates = num_variates

if self.prior is not None:
self.prior = np.asarray(self.prior)
self.prior = self.prior / self.prior.sum()
if self.prior.size != self.num_variates:
raise ValueError(
f"The size of split_prior ({self.prior.size}) should be the "
f"same as the number of covariates ({self.num_variates})"
)
self.enu = list(enumerate(np.cumsum(self.prior)))

def rvs(self):
if self.prior is None:
return int(np.random.random() * self.num_variates)
else:
r = np.random.random()
for i, v in self.enu:
if r <= v:
return i
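
# e.g. split_prior=[1, 1, 2] is normalized to [0.25, 0.25, 0.5], so the third
# covariate is proposed as the splitting variable twice as often as the others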
bart = BARTRV()


class BART(BaseBART):
"""
BART distribution.

Distribution representing a sum over trees

Parameters
----------
X : array-like
The design matrix.
Y : array-like
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
although it is recommended to be in the interval (0, 0.5].
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1. Otherwise they will be normalized.
Defaults to None, i.e. all variables have the same prior probability.
"""

def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
super().__init__(X, Y, m, alpha, split_prior)


class BART(NoDistribution):
"""
Bayesian Additive Regression Tree distribution.

Distribution representing a sum over trees

Parameters
----------
X : array-like
The covariate matrix.
Y : array-like
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Even though it can take values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
k : float
Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between 1
and 3.
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to None, i.e. all covariates have the same prior probability to be selected.
"""

def __new__(
cls,
name,
X,
Y,
m=50,
alpha=0.25,
k=2,
split_prior=None,
**kwargs,
):

cls.all_trees = []

bart_op = type(
f"BART_{name}",
(BARTRV,),
dict(
name="BART",
all_trees=cls.all_trees,
inplace=False,
initval=Y.mean(),
X=X,
Y=Y,
m=m,
alpha=alpha,
k=k,
split_prior=split_prior,
),
)()

NoDistribution.register(BARTRV)

cls.rv_op = bart_op
params = [X, Y, m, alpha, k]
return super().__new__(cls, name, *params, **kwargs)

@classmethod
def dist(cls, *params, **kwargs):
return super().dist(params, **kwargs)
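
A minimal usage sketch of the API this commit brings back (the data and the names mu, sigma, y are illustrative; pm.sample is assumed to automatically assign the PGBART step sampler to the BART variable):

import numpy as np
import pymc3 as pm

X = np.random.uniform(0, 10, size=(100, 2))
Y = np.sin(X[:, 0]) + np.random.normal(0, 0.5, size=100)

with pm.Model():
    mu = pm.BART("mu", X, Y, m=50)  # sum-of-trees prior on the conditional mean
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu, sigma, observed=Y)
    idata = pm.sample()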