Truncated Normals Using Rejection Sampling

Let’s do the basic example of rejection sampling we looked at in the notes.

We want to draw from a \(N(0,1)\) truncated to values greater than \(a\).

Our proposal is \[ Y = a + X, \quad X \sim \text{Exp}(a) \]

That is, add \(a\) to an exponential draw with rate \(a\), so that \(g(y) = a e^{-a(y-a)}\) for \(y \ge a\).

Remember that we accept a draw \(y\) from \(g\) with probability \(h(y)\), given below; otherwise we draw again.

\[ h(y) = \frac{f_1(y)}{M g(y)} \]

In this case we computed the optimal \(M\).
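
The value of \(M\) can be recovered in a few lines. This is a sketch that assumes, as in the code that follows, that \(f_1(y) = e^{-y^2/2}\) is the unnormalized target and \(g(y) = a e^{-a(y-a)}\) on \(y \ge a\):

\[ \frac{f_1(y)}{g(y)} = \frac{1}{a} \exp\left( -\frac{y^2}{2} + a(y-a) \right), \quad y \ge a \]

The exponent is concave in \(y\) with derivative \(-y + a\), so the ratio is maximized at \(y = a\):

\[ M = \max_{y \ge a} \frac{f_1(y)}{g(y)} = \frac{e^{-a^2/2}}{a}, \qquad h(y) = \frac{f_1(y)}{M g(y)} = \exp\left( -\frac{(y-a)^2}{2} \right) \]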

##################################################
### compute f1,g, M and plot
a = 3
yv = seq(from=a, to=a+2,length.out=1000)

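fy1 = exp(-.5*yv^2)        # unnormalized N(0,1) density f1
gy = a*exp(-a*(yv-a))      # proposal density g: a + Exp(rate=a)

M = (exp((-a^2)/2)/a)      # optimal M: max of f1/g over y >= a, attained at y = a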

plot(range(yv),range(c(fy1,M*gy)),type="n")
lines(yv,fy1,lwd=2)
lines(yv,M*gy,col="blue",lwd=3,lty=2)

Wow, that looks great. \(Mg\) is above \(f_1\) and they look reasonably similar.

Let’s plot just the tail areas so we can see what is going on there.

ylm = range(M*gy[yv>4])
plot(range(yv),range(c(0,fy1,M*gy)),type="n",xlim=c(4,5),ylim=ylm,xlab='y',ylab='Mg and f1')
lines(yv,fy1,lwd=2)
lines(yv,M*gy,col="blue",lwd=3,lty=2)
legend('topright',legend=c('Mg','f1'),col=c('blue','black'),lty=c(2,1),bty='n')

We can see that \(Mg\) stays above \(f_1\).

Let’s compute \(h\) two ways, directly as \(f_1/(Mg)\) and from its closed form, check that they agree, and plot it.

hycheck = fy1/(M*gy)
hy = exp(-.5*(yv-a)^2)
summary(hy-hycheck)
##       Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
## -5.551e-16 -1.665e-16  0.000e+00 -1.740e-17  1.110e-16  5.551e-16

plot(yv,hy,xlab="y",ylab="prob of accepting")

This looks good. The probability of accepting is pretty high for the plausible draws from \(g\).
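
To put a number on "pretty high", the overall acceptance rate can be computed directly. The following is a minimal sketch, assuming the same unnormalized \(f_1(y) = e^{-y^2/2}\), proposal \(g\), and \(M = e^{-a^2/2}/a\) used in the code above; the names acc_exact and acc_mc are illustrative and not from the notes. The rate is the integral of \(f_1\) over \((a, \infty)\) divided by \(M\).

```r
a <- 3

# Exact rate: integral of f1 over (a, Inf) divided by M,
# which simplifies to a * P(Z > a) / dnorm(a).
acc_exact <- a * pnorm(a, lower.tail = FALSE) / dnorm(a)

# Monte Carlo check: average h(y) over proposal draws y = a + x, x ~ Exp(rate = a).
set.seed(1)
x <- rexp(1e5, rate = a)
acc_mc <- mean(exp(-0.5 * x^2))

print(c(exact = acc_exact, monte_carlo = acc_mc))
```

For \(a = 3\) the exact rate is about 0.91, so roughly nine proposals in ten are accepted.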

Exponential rejection sampling function

Let’s write an R function to do the rejection sampling.

trn = function(a) {
   if(a<2) stop("you have not thought this through")

   done=FALSE

   while(!done) {
      x = rexp(1,rate=a)   # proposal is y = x + a with x ~ Exp(rate=a)
      h = exp(-.5*x^2)     # h(y) = exp(-(y-a)^2/2), and y - a = x
      u = runif(1)
      if(u < h) {
         done=TRUE
      }
   }
   return(x+a)
}
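
The same accept/reject step can also be applied to a whole batch of proposals at once. This is a minimal sketch, not the function from the notes; trn_vec is a hypothetical name, and it assumes the same proposal and acceptance probability as trn.

```r
# Hypothetical vectorized variant of the exponential rejection sampler.
trn_vec <- function(n, a) {
  out <- numeric(0)
  while (length(out) < n) {
    m <- n - length(out)                 # propose only as many as are still needed
    x <- rexp(m, rate = a)               # proposals are x + a with x ~ Exp(rate = a)
    keep <- runif(m) < exp(-0.5 * x^2)   # accept with probability h = exp(-x^2/2)
    out <- c(out, x[keep] + a)
  }
  out
}

set.seed(1)
draws_vec <- trn_vec(5000, 3)
print(summary(draws_vec))
```

Because each pass proposes exactly the number of draws still missing, the result always has length n; with a high acceptance rate only a few passes are needed.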

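We can compare the results from the rejection sampling to “dumb rejection sampling”, where we just draw from N(0,1) until we get a value bigger than \(a\). This would not work for very large \(a\): the chance that a single draw exceeds \(a\) is \(1-\Phi(a)\), where \(\Phi\) is the standard normal CDF, and it shrinks rapidly as \(a\) grows.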

trnD = function(a) {
   done=FALSE
   while(!done) {
      z = rnorm(1)
      if(z >=a) {
         done=TRUE
      }
   }
   return(z)
}
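
How quickly dumb rejection degrades can be quantified. A small sketch, assuming only that each N(0,1) draw is accepted independently with probability \(P(Z > a)\); the variable names are illustrative.

```r
a_vals <- c(2, 3, 4, 5)

# Each N(0,1) draw exceeds a with probability p = P(Z > a); the number of draws
# needed per accepted value is geometric with mean 1/p.
p_accept <- pnorm(a_vals, lower.tail = FALSE)
expected_draws <- 1 / p_accept

print(data.frame(a = a_vals, p_accept = p_accept, expected_draws = expected_draws))
```

At \(a = 3\) this is roughly 740 normal draws per accepted value, and it grows into the millions by \(a = 5\).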

Let’s try these with \(a=3\) and 5,000 draws.

a=3
nd = 5000

Draw using exponential rejection:

drv = rep(0,nd)
tm1 = system.time({
for(i in 1:nd) {
   drv[i] = trn(a)
}
})

Draw using dumb rejection:

drvD = rep(0,nd)
tm2 = system.time({
for(i in 1:nd) {
   drvD[i] = trnD(a)
}
})

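Draw using dumb rejection, vectorized. We draw about nd/P(Z > a) normals in one call and keep those that are at least \(a\), so the number kept is only approximately nd: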

tm3 = system.time({
drv1 = rnorm(nd/(1-pnorm(a)))
drv1 = drv1[drv1>=a]
})

Let’s have a look at the draws and check that they look correct.

trden = dnorm(yv)/(1-pnorm(a))
par(mfrow=c(1,2))
hist(drv1,prob=TRUE); lines(yv,trden,col='blue')
hist(drv,prob=TRUE); lines(yv,trden,col='blue')

Check draws with qqplots using the vectorized dumb draws as the gold standard.

par(mfrow=c(1,2))
qqplot(drv1,drv) #compare reject draws to vectorized dumb
abline(0,1,col="red",lwd=3)
qqplot(drv1,drvD) # compare dumbs to vectorized dumb
abline(0,1,col="red",lwd=3)
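
A numerical check can complement the QQ plots. The sketch below is self-contained, so it regenerates draws with the exponential rejection step rather than reusing drv. The helper ptrunc and the other names are hypothetical. It compares the sample mean with the theoretical mean of the truncated normal, \(\phi(a)/(1-\Phi(a))\), and runs a Kolmogorov-Smirnov test against the truncated CDF.

```r
set.seed(1)
a <- 3

# Regenerate draws with the exponential rejection step so the block stands alone.
x <- rexp(20000, rate = a)
draws <- (x + a)[runif(20000) < exp(-0.5 * x^2)]

# CDF of N(0,1) truncated to (a, Inf); hypothetical helper that uses a from above.
ptrunc <- function(q) (pnorm(q) - pnorm(a)) / pnorm(a, lower.tail = FALSE)

# Theoretical mean of the truncated normal: dnorm(a) / P(Z > a).
theo_mean <- dnorm(a) / pnorm(a, lower.tail = FALSE)

print(c(sample_mean = mean(draws), theoretical_mean = theo_mean))
print(ks.test(draws, ptrunc))
```

A sample mean close to the theoretical one, together with a KS p-value that is not small, is consistent with the sampler being correct, although neither proves it.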

print(tm1)
##    user  system elapsed 
##   0.045   0.008   0.053
print(tm2)
##    user  system elapsed 
##   6.170   0.008   6.180
print(tm3)
##    user  system elapsed 
##   0.232   0.008   0.240

Try inverse CDF

Let’s use the inverse CDF method.
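
The formula used in the function below can be written out as follows, with \(\Phi\) denoting the standard normal CDF. The CDF of the truncated normal is

\[ F_a(y) = \frac{\Phi(y) - \Phi(a)}{1 - \Phi(a)}, \quad y \ge a \]

Setting \(F_a(y) = u\) for \(u \sim U(0,1)\) and solving for \(y\) gives

\[ y = \Phi^{-1}\left( \Phi(a) + u\,(1 - \Phi(a)) \right) \]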

drcdf = function(n,a) {
   u = runif(n)
   Fa = pnorm(a)
   temp = u*(1-Fa) + Fa
   return(qnorm(temp))
}
tm6 = system.time({
   drC = drcdf(nd,a)
})
print(tm6)
##    user  system elapsed 
##   0.000   0.000   0.001

Check the draws:

qqplot(drC,drv1)
abline(0,1,col='red')
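
One caveat for large \(a\): pnorm(a) rounds to exactly 1 in double precision, so the direct formula returns Inf. A possible workaround, sketched here as an assumption rather than as part of the notes, is to work with the upper tail; drcdf_tail is a hypothetical name.

```r
# Hypothetical variant of the inverse CDF sampler that works on the upper tail.
drcdf_tail <- function(n, a) {
  u <- runif(n)
  p_tail <- pnorm(a, lower.tail = FALSE)   # P(Z > a), computed without 1 - pnorm(a)
  # u * p_tail is uniform on (0, P(Z > a)); map it back through the upper-tail quantile.
  qnorm(u * p_tail, lower.tail = FALSE)
}

set.seed(1)
d_tail <- drcdf_tail(1000, 10)
print(range(d_tail))                                    # finite values just above 10
print(qnorm(runif(1) * (1 - pnorm(10)) + pnorm(10)))    # the direct formula gives Inf here
```

For moderate \(a\), such as the \(a = 3\) used above, the two versions sample from the same distribution.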