Variable Selection Using BART

We present two approaches.

In our first approach we use the R package BART which is on CRAN.
This approach examines the BART draws to see which variables are being used in the decision trees.

In the second approach we use the R package nonlinvarsel which is available at http://www.rob-mcculloch.org.
This approach searches for variable subsets which can be used to approximate predictions based on the information in the complete set of variables.
See “A Utility Based Approach to Variable Selection” below.

Variable Selection Using the BART R Package

In this note we will illustrate the use of the R package BART to do variable selection.

The basic advantages of this approach are the effective stochastic search of the BART algorithm and the simple variable usage counts.
There is no need to elaborate the model.

The only tricky aspect is that, to get a good picture of variable usage, it is recommended that fewer trees be used in the sum-of-trees model.

The Data

To keep things simple we will use the time-honored Boston Housing data.

Each observation corresponds to a region in the Boston area.
The response is the median house price in the region.
The explanatory features measure various things about the region.

library(MASS)
attach(Boston)
names(Boston)
##  [1] "crim"    "zn"      "indus"   "chas"    "nox"     "rm"      "age"    
##  [8] "dis"     "rad"     "tax"     "ptratio" "black"   "lstat"   "medv"

y = Boston$medv
x = Boston[,1:13]
p = ncol(x)

Running BART

First we load the BART R package.

We use burn = 10000 burn-in draws and nd = 10000 kept draws.
For the kept draws of the function, \(f_d\), \(d=1,2,\ldots, \text{nd}\), BART will evaluate \(f_d(x)\) for train and test (if supplied) values of \(x\).

We set the seed.

library(BART)
## Loading required package: nlme
## Loading required package: nnet
## Loading required package: survival

set.seed(99) # Wayne Gretzky
burn=10000; nd=10000

Now we run BART using the default prior and model settings.

# run default BART
bf = wbart(x,y,nskip=burn,ndpost=nd,printevery=5000) #print progress every 5000th MCMC iteration
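As noted above, if test \(x\)'s are supplied, BART also evaluates the draws \(f_d(x)\) at those \(x\)'s. A quick sketch of this (just reusing a few training rows and a short run to show the shape of the output; not needed for the rest of this note):

# sketch: supplying x.test makes wbart return draws f_d(x) for those x's in yhat.test
bftest = wbart(x, y, x.test=x[1:5,], nskip=100, ndpost=100, printevery=100)
dim(bftest$yhat.test)  # ndpost rows, one column per test x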

The sigma component of bf contains all 10,000 + 10,000 = 20,000 draws of sigma, that is, we include the burn-in draws.
This way we can check that we chose a reasonable value for the number of burn-in draws.

# linear model fit
lmf = lm(medv~.,Boston)
plot(bf$sigma,ylim=c(1.5,5),xlab="MCMC iteration",ylab="sigma draw",cex=.5)
abline(h=summary(lmf)$sigma,col="red",lty=2) #least squares estimates
abline(v = burn,col="green")
title(main="sigma draws, green line at burn in, red line at least squares estimate",cex.main=.8)

Our burn-in choice looks reasonable in that the sigma draws have “settled down” by the 1,000th draw.
The draws exhibit substantial dependence, but seem like a reasonable exploration of plausible sigma values given the data.
The post burn-in sigma values suggest BART has found substantially more fit on the training data than the simple linear model.
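A quick numerical check of that last claim, using objects already computed above:

# average post burn-in sigma draw vs. the least squares residual standard error
mean(bf$sigma[(burn+1):(burn+nd)])
summary(lmf)$sigma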

For example, if we thin to every 20th post burn-in draw, we get a non-negligible but reasonable ACF.

thin = 20
ii = burn + thin*(1:(nd/thin))
acf(bf$sigma[ii],main="ACF of thinned post burn-in sigma draws")

The default BART run uses 200 trees in the sum-of-trees model.

To get a good sense of the variable selection, we will see that it helps to use fewer trees.
Let’s run BART, but just use 20 trees in the sum.

# run BART with 20 trees in the sum
set.seed(99)
bf20 = wbart(x,y,nskip=burn,ndpost=nd,ntree=20,printevery=5000)

Let’s compare the in-sample fits from the two BART runs.

fitmat = cbind(y,bf$yhat.train.mean,bf20$yhat.train.mean)
colnames(fitmat) = c("y","yhatBART","yhatBART20")
pairs(fitmat)

print(cor(fitmat))
##                    y  yhatBART yhatBART20
## y          1.0000000 0.9884595  0.9800918
## yhatBART   0.9884595 1.0000000  0.9924667
## yhatBART20 0.9800918 0.9924667  1.0000000

We see that both BART and BART20 fit the data (in-sample) very well.

In general, BART20 may not fit quite as well, but we will see in the next section that using fewer trees sharpens our variable selection.
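One simple way to quantify “may not fit quite as well” is to compare the in-sample root mean squared errors of the two posterior mean fits; a quick sketch:

# in-sample RMSE of the posterior mean fits from the two runs
rmse = function(y,yhat) sqrt(mean((y-yhat)^2))
c(BART = rmse(y,bf$yhat.train.mean), BART20 = rmse(y,bf20$yhat.train.mean))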

Variable Selection

The component varcount of our BART fit provides information about which variables are used in the trees.

dim(bf20$varcount)
## [1] 10000    13

Each row of varcount corresponds to one of our 10,000 kept post burn-in MCMC draws of the sum-of-trees model.

Each column corresponds to one of our 13 x variables.
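For example, a single row shows how many splitting rules use each variable across the 20 trees at that draw:

# splitting rule counts for each variable at the first kept draw
bf20$varcount[1,]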

We just have to summarize the values in varcount.

#compute row percentages
percount20 = bf20$varcount/apply(bf20$varcount,1,sum)
# mean of row percentages
mvp20 = apply(percount20,2,mean)
# quantiles of row percentages
qm = apply(percount20,2,quantile,probs=c(.05,.95))

print(mvp20)
##       crim         zn      indus       chas        nox         rm        age 
## 0.05088151 0.03225757 0.02158979 0.01286549 0.14385611 0.15979102 0.05695193 
##        dis        rad        tax    ptratio      black      lstat 
## 0.09558875 0.04908700 0.09949849 0.08841762 0.04502477 0.14418993

rgy = range(qm)
plot(c(1,p),rgy,type="n",xlab="variable",ylab="post mean, percent var use",axes=FALSE)
axis(1,at=1:p,labels=names(mvp20),cex.lab=0.7,cex.axis=0.7)
axis(2,cex.lab=1.2,cex.axis=1.2)
lines(1:p,mvp20,col="black",lty=4,pch=4,type="b",lwd=1.5)
for(i in 1:p) {
   lines(c(i,i),qm[,i],col="blue",lty=3,lwd=1.0)
}

We see that nox, rm, and lstat get used more in the trees than the other variables.

Let’s compare the BART default run to the BART20 run.

percount = bf$varcount/apply(bf$varcount,1,sum)
mvp = apply(percount,2,mean)
plot(mvp20,xlab="variable number",ylab="post mean, percent var use",col="blue",type="b")
lines(mvp,type="b",col='red')
legend("topleft",legend=c("BART","BART20"),col=c("red","blue"),lty=c(1,1))

We see that the default BART run spreads usage across the variables more evenly. With 200 trees and heavy shrinkage, a variable can appear in a tree while the corresponding bottom node means are shrunk so close to 0 that the variable is of no real importance.

With 20 trees, the means are shrunk less heavily, so a variable tends to come in only when it actually does something!
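A side-by-side look at the posterior mean usage percentages makes the contrast concrete:

# usage percentages: more uniform with 200 trees, more concentrated with 20
round(rbind(BART = mvp, BART20 = mvp20), 3)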

A Utility Based Approach to Variable Selection

In this section we illustrate a utility based approach to variable selection using BART inference.

The R package nonlinvarsel is available at http://www.rob-mcculloch.org/.

This approach takes any inference for a nonlinear function \(\hat{y} = \hat{f}(x)\) and then seeks a subset \(S\) of the variables in \(x\) and a nonlinear function \(g_S\) such that \(g_S(x_S) \approx \hat{f}(x)\), where \(x_S\) denotes the variables in the subset. The approximation should be reasonably close as a practical matter. Practical significance is emphasized over statistical significance, rather than the other way around.

If we are able to find such a subset of variables, then clearly we can make predictions based on the subset which are as good as those using the full information in \(x\) and \(\hat{f}\).

Many modern statistical/machine learning methods give us a remarkably flexible \(\hat{f}\) given training data, and the approach may be applied with any such \(\hat{f}\). In addition, BART gives us draws \(f_d\), \(d=1,2,\ldots, D\), from the posterior of \(f\). In this case we can assess our uncertainty by looking at the values \(d(f_d,g_S)\), \(d=1,2,\ldots, D\), where \(d(\cdot,\cdot)\) is a metric chosen to capture the practical difference in predictions.
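For intuition, here is a minimal sketch of such a metric: root mean squared difference between prediction vectors over the training \(x\)'s. This is an assumption about the form of the metric; the actual package function used below is distrmse. To illustrate, we compare each BART draw to the posterior mean fit (in practice the comparison would be to a subset-based approximation \(g_S\)).

# sketch of a practical distance between two prediction vectors
predrmse = function(fhat, ghat) sqrt(mean((fhat - ghat)^2))
# distances between each posterior draw f_d and an approximation ghat
dvals = apply(bf$yhat.train, 1, predrmse, ghat = bf$yhat.train.mean)
summary(dvals)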

Assessing Uncertainty

To assess uncertainty, we look at the posterior distribution of our approximation error. We use the approximations from the BART refits in the all subsets search, with root mean squared error (computed on the training data in this case) as the measure of approximation error.

sp = sumpost(bf$yhat.train,bsubs,bf$yhat.train.mean,distrmse)
## nd,n,p:  10000 506 12
plot(sp,diff=FALSE)

It can also be helpful to look at the difference between the approximation error using a variable subset and the approximation error using the full set of variables.

plot(sp)

summary(y)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    5.00   17.02   21.20   22.53   25.00   50.00

We see that, as a practical matter, with high posterior probability, we can get good predictions using just three, or at most seven, of the variables.

Here are the variable subsets from the all subsets search using 3, 4 and 7 variables.

names(x)[vsar$vL[[3]]]
## [1] "nox"   "rm"    "lstat"
names(x)[vsar$vL[[4]]]
## [1] "nox"   "rm"    "rad"   "lstat"
names(x)[vsar$vL[[7]]]
## [1] "nox"     "rm"      "age"     "dis"     "ptratio" "black"   "lstat"

And forwards:

names(x)[vsfr$vL[[3]]]
## [1] "lstat" "rm"    "nox"
names(x)[vsfr$vL[[4]]]
## [1] "lstat" "rm"    "nox"   "rad"
names(x)[vsfr$vL[[7]]]
## [1] "lstat"   "rm"      "nox"     "rad"     "tax"     "ptratio" "black"