## SGD, SGD with momentum, and Newton’s Method

source("http://www.rob-mcculloch.org/2023_cs/webpage/hw/logit-funs.R")
## Now let’s simulate some data.

## simulation settings: sample size and true coefficients
n <- 200
beta <- c(1, 2)
p <- length(beta)

## x2 is x1 plus a little independent noise, so the two predictors are
## strongly correlated (see the printed correlation)
set.seed(17)
x1 <- rnorm(n)
wht <- .7
x2 <- wht * x1 + (1 - wht) * rnorm(n)
print(cor(x1, x2))
## [1] 0.938045

X <- cbind(x1, x2)
y <- simData(X, beta)

plot(x1, x2)

## logit mle: logistic regression with no intercept, fitted by glm
ddf <- data.frame(X, y)
lgm <- glm(y ~ . - 1, data = ddf, family = "binomial")
print(summary(lgm))
## 
## Call:
## glm(formula = y ~ . - 1, family = "binomial", data = ddf)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.45597  -0.52606   0.04267   0.57679   2.66536  
## 
## Coefficients:
##    Estimate Std. Error z value Pr(>|z|)  
## x1   1.2602     0.5781   2.180   0.0293 *
## x2   1.9116     0.7703   2.482   0.0131 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 277.26  on 200  degrees of freedom
## Residual deviance: 147.49  on 198  degrees of freedom
## AIC: 151.49
## 
## Number of Fisher Scoring iterations: 6
## mle from R
bhat <- coef(lgm)

## check: the gradient returned by lgGrad should be ~0 at the mle
cat("grad a mle: ", lgGrad(X, y, bhat), "\n")
## grad a mle:  -7.884589e-14 -6.322526e-14
### p=2, evaluate the objective mLL on a grid of (beta1, beta2) values

gs <- 50 # one-d grid size
alga1 <- seq(from = 0, to = 3.0, length.out = gs) # one-d grid for beta1
alga2 <- seq(from = 0, to = 3.0, length.out = gs) # one-d grid for beta2
alg2 <- expand.grid(alga1, alga2) # two-d grid (first column varies fastest)

## mLL at every grid point, printing progress every 100 points
llv <- rep(0, nrow(alg2))
for (i in seq_along(llv)) {
   if (i %% 100 == 0) cat(i, " out of ", length(llv), "\n")
   llv[i] <- mLL(X, y, as.double(alg2[i, ]))
}
## 100  out of  2500 
## 200  out of  2500 
## 300  out of  2500 
## 400  out of  2500 
## 500  out of  2500 
## 600  out of  2500 
## 700  out of  2500 
## 800  out of  2500 
## 900  out of  2500 
## 1000  out of  2500 
## 1100  out of  2500 
## 1200  out of  2500 
## 1300  out of  2500 
## 1400  out of  2500 
## 1500  out of  2500 
## 1600  out of  2500 
## 1700  out of  2500 
## 1800  out of  2500 
## 1900  out of  2500 
## 2000  out of  2500 
## 2100  out of  2500 
## 2200  out of  2500 
## 2300  out of  2500 
## 2400  out of  2500 
## 2500  out of  2500
## contour plot of mLL over the grid
## expand.grid varies alga1 fastest, so a column-major fill puts the value at
## (alga1[i], alga2[j]) into llmat[i, j], which is the layout contour() expects
llmat <- matrix(as.numeric(llv), nrow = length(alga1), ncol = length(alga2))
contour(x = alga1, y = alga2, z = llmat,
        nlevels = 100, drawlabels = FALSE, col = "blue", lwd = 1,
        xlab = expression(beta[1]), ylab = expression(beta[2]),
        cex.lab = 1.8, cex.axis = 1.2)
## red dot: true beta; magenta dot: glm mle
points(beta[1], beta[2], pch = 16, cex = 1.5, col = "red")
points(bhat[1], bhat[2], pch = 16, cex = 1.5, col = "magenta")
title(main = "contours of -log likelihood, red dot at true value, magenta at mle",
      cex.main = .8)

### stochastic gradient descent (constant learning rate)

## mini-batch size, batches per epoch, epochs, total number of updates
bs <- 10
nbatch <- n / bs
nepoch <- 50
niter <- nepoch * nbatch

## starting value and storage: row k+1 of bM (entry k+1 of lvv) holds the
## iterate (full-data mLL) after update k; row 1 is the starting value
biter <- c(1.2, .2)
bM <- matrix(0.0, nrow = niter + 1, ncol = length(biter))
bM[1, ] <- biter
lvv <- numeric(niter + 1)
lvv[1] <- mLL(X, y, biter)

## one learning rate per epoch, all equal here
lrat <- rep(5, nepoch)

## cv[k] is the epoch that produced iterate k (0 for the start); used for colors
cv <- numeric(niter + 1)
inum <- 1
for (j in seq_len(nepoch)) {
   cat("** on epoch: ", j, "\n")
   for (i in seq_len(nbatch)) {
      ## batch i is the i-th consecutive block of bs rows
      rows <- seq(from = (i - 1) * bs + 1, length.out = bs)
      gv <- lgGrad(X[rows, ], y[rows], biter)
      biter <- biter - lrat[j] * gv
      inum <- inum + 1
      bM[inum, ] <- biter
      lvv[inum] <- mLL(X, y, biter)
      cv[inum] <- j
   }
}
## ** on epoch:  1 
## ** on epoch:  2 
## ** on epoch:  3 
## ** on epoch:  4 
## ** on epoch:  5 
## ** on epoch:  6 
## ** on epoch:  7 
## ** on epoch:  8 
## ** on epoch:  9 
## ** on epoch:  10 
## ** on epoch:  11 
## ** on epoch:  12 
## ** on epoch:  13 
## ** on epoch:  14 
## ** on epoch:  15 
## ** on epoch:  16 
## ** on epoch:  17 
## ** on epoch:  18 
## ** on epoch:  19 
## ** on epoch:  20 
## ** on epoch:  21 
## ** on epoch:  22 
## ** on epoch:  23 
## ** on epoch:  24 
## ** on epoch:  25 
## ** on epoch:  26 
## ** on epoch:  27 
## ** on epoch:  28 
## ** on epoch:  29 
## ** on epoch:  30 
## ** on epoch:  31 
## ** on epoch:  32 
## ** on epoch:  33 
## ** on epoch:  34 
## ** on epoch:  35 
## ** on epoch:  36 
## ** on epoch:  37 
## ** on epoch:  38 
## ** on epoch:  39 
## ** on epoch:  40 
## ** on epoch:  41 
## ** on epoch:  42 
## ** on epoch:  43 
## ** on epoch:  44 
## ** on epoch:  45 
## ** on epoch:  46 
## ** on epoch:  47 
## ** on epoch:  48 
## ** on epoch:  49 
## ** on epoch:  50
## full-data mLL after each SGD update
plot(lvv)

## contour of mLL with the SGD path overlaid, colored by epoch
llmat <- matrix(as.numeric(llv), nrow = length(alga1), ncol = length(alga2))
contour(x = alga1, y = alga2, z = llmat,
        nlevels = 100, drawlabels = FALSE, col = "blue", lwd = 1,
        xlab = expression(beta[1]), ylab = expression(beta[2]),
        cex.lab = 1.8, cex.axis = 1.2)
points(beta[1], beta[2], pch = 16, cex = 1.5, col = "red")
points(bhat[1], bhat[2], pch = 16, cex = 1.5, col = "magenta")
fnm <- paste0("Stochastic Grad descent, learning rate= ", lrat[1],
              ", nepoch= ", nepoch, ", batch size= ", bs)
title(main = fnm, cex.main = .8)

## points() is vectorized over coordinates and colors
points(bM[, 1], bM[, 2], col = cv, pch = 16)

### sgd with momentum
## update: v <- gam * v - eta * grad;  beta <- beta + v

## batch / epoch choices
bs <- 10
nbatch <- n / bs
nepoch <- 50
niter <- nepoch * nbatch

## momentum and learning-rate schedules: one entry per update (constant here)
gam <- .9
eta <- 5.0
gammav <- rep(gam, niter)
etav <- rep(eta, niter) # works

## init beta and storage
biter <- c(1.2, .2) # starting value
bM <- matrix(0.0, niter + 1, length(biter))
bM[1, ] <- biter
lvv <- rep(0, niter + 1)
lvv[1] <- mLL(X, y, biter)

## zero initial velocity (one entry per coefficient) and counters
v <- matrix(0, nrow = length(biter), ncol = 1)
inum <- 1
cv <- rep(0, niter + 1) # epoch of each iterate, used for colors
for (j in 1:nepoch) {
   cat("** on epoch: ", j, "\n")
   for (i in 1:nbatch) {
      ii <- ((i - 1) * bs + 1):(i * bs)
      gv <- lgGrad(X[ii, ], y[ii], biter)
      ## the schedules have one entry per update, so index them by the global
      ## update count inum, not the within-epoch batch index i
      v <- gammav[inum] * v - etav[inum] * gv
      biter <- biter + v

      lvv[inum + 1] <- mLL(X, y, biter)
      bM[inum + 1, ] <- biter
      cv[inum + 1] <- j
      inum <- inum + 1
   }
}
## ** on epoch:  1 
## ** on epoch:  2 
## ** on epoch:  3 
## ** on epoch:  4 
## ** on epoch:  5 
## ** on epoch:  6 
## ** on epoch:  7 
## ** on epoch:  8 
## ** on epoch:  9 
## ** on epoch:  10 
## ** on epoch:  11 
## ** on epoch:  12 
## ** on epoch:  13 
## ** on epoch:  14 
## ** on epoch:  15 
## ** on epoch:  16 
## ** on epoch:  17 
## ** on epoch:  18 
## ** on epoch:  19 
## ** on epoch:  20 
## ** on epoch:  21 
## ** on epoch:  22 
## ** on epoch:  23 
## ** on epoch:  24 
## ** on epoch:  25 
## ** on epoch:  26 
## ** on epoch:  27 
## ** on epoch:  28 
## ** on epoch:  29 
## ** on epoch:  30 
## ** on epoch:  31 
## ** on epoch:  32 
## ** on epoch:  33 
## ** on epoch:  34 
## ** on epoch:  35 
## ** on epoch:  36 
## ** on epoch:  37 
## ** on epoch:  38 
## ** on epoch:  39 
## ** on epoch:  40 
## ** on epoch:  41 
## ** on epoch:  42 
## ** on epoch:  43 
## ** on epoch:  44 
## ** on epoch:  45 
## ** on epoch:  46 
## ** on epoch:  47 
## ** on epoch:  48 
## ** on epoch:  49 
## ** on epoch:  50
## full-data mLL after each momentum-SGD update
plot(lvv)

## contour of mLL with the momentum-SGD path overlaid
llmat <- matrix(as.numeric(llv), nrow = length(alga1), ncol = length(alga2))
contour(alga1, alga2, llmat, nlevels = 100, drawlabels = FALSE,
        xlab = expression(beta[1]), ylab = expression(beta[2]), col = "blue",
        cex.lab = 1.8, cex.axis = 1.2, lwd = 1)
## true value in red (a blue dot disappears into the blue contours) and
## mle in magenta, the same convention as the first contour plot
points(beta[1], beta[2], cex = 1.5, pch = 16, col = "red")
points(bhat[1], bhat[2], cex = 1.5, pch = 16, col = "magenta")
fnm <- paste0("SGD with momentum, learning rate= ", etav[1], ", gamma: ", gammav[1], " niter= ", niter)
title(main = fnm, cex.main = .8)
## one arrow per update (iterate k -> k+1), colored by epoch;
## arrows() is vectorized, so no loop is needed
nr <- nrow(bM)
arrows(bM[-nr, 1], bM[-nr, 2], bM[-1, 1], bM[-1, 2], length = .05, col = cv[-1])

## save the momentum-SGD results for comparison with sgdF() below
slvv <- lvv
sbM <- bM
### newton's method

## init beta and storage (same starting value as the SGD runs)
niter <- 20
biter <- c(1.2, .2)
bM <- matrix(0.0, nrow = niter + 1, ncol = length(biter))
bM[1, ] <- biter
lvv <- numeric(niter + 1)
lvv[1] <- mLL(X, y, biter)

## check the Hessian at the mle is positive definite (all eigenvalues > 0)
tmp <- eigen(lgH(X, bhat))
print(tmp)
## eigen() decomposition
## $values
## [1] 0.054139182 0.005986287
## 
## $vectors
##            [,1]       [,2]
## [1,] -0.8212624  0.5705507
## [2,] -0.5705507 -0.8212624
## Newton iterations: beta <- beta - H^{-1} grad
## solve(H, gv) solves H d = gv directly, which is cheaper and numerically
## more stable than forming solve(H) and then multiplying; drop() keeps
## biter a plain vector rather than a one-column matrix
for (i in seq_len(niter)) {
   H <- lgH(X, biter)
   gv <- lgGrad(X, y, biter)
   biter <- biter - drop(solve(H, gv))
   lvv[i + 1] <- mLL(X, y, biter)
   bM[i + 1, ] <- biter
}
## mLL by Newton iteration; the red line is mLL at the glm mle
plot(lvv)
abline(h = mLL(X, y, bhat), col = "red", lwd = 2)

llmat <- matrix(as.numeric(llv), nrow = length(alga1), ncol = length(alga2))
contour(alga1, alga2, llmat, nlevels = 100, drawlabels = FALSE,
        xlab = expression(beta[1]), ylab = expression(beta[2]), col = "blue",
        cex.lab = 1.8, cex.axis = 1.2, lwd = 1)
## true value in red (a blue dot disappears into the blue contours), mle in magenta
points(beta[1], beta[2], cex = 1.5, pch = 16, col = "red")
points(bhat[1], bhat[2], cex = 1.5, pch = 16, col = "magenta")
fnm <- paste0("Newton's method")
title(main = fnm, cex.main = .8)
## draw each Newton step (iterate k -> k+1) that moves visibly; once the method
## has converged the steps are (near) zero length, and arrows() would warn and
## skip them. Drawn in black: cv holds SGD epoch labels and is unrelated to
## Newton iterations.
steplen <- sqrt(rowSums(diff(bM)^2))
moved <- which(steplen > 1e-3)
arrows(bM[moved, 1], bM[moved, 2], bM[moved + 1, 1], bM[moved + 1, 2],
       length = .05, col = "black")

print(cbind(bhat, bM[niter + 1, ]))
##        bhat         
## x1 1.260224 1.260224
## x2 1.911650 1.911650
## Newton's method for minimizing mLL (the -log likelihood objective from
## logit-funs.R), using lgGrad for the gradient and lgH for the Hessian.
##   X:     design matrix
##   y:     response
##   niter: number of Newton steps (0 returns just the starting point)
##   biter: starting coefficients, default all zeros
## Returns list(mll = mLL at each iterate (first entry is the start),
##              bM  = (niter+1) x length(biter) matrix of iterates, row 1 = start)
newtF <- function(X, y, niter = 10, biter = rep(0, ncol(X))) {
   bM <- matrix(0.0, niter + 1, length(biter))
   bM[1, ] <- biter
   lvv <- rep(0, niter + 1)
   lvv[1] <- mLL(X, y, biter)
   ## seq_len(): with niter = 0, 1:niter would be c(1, 0) and index bM[2, ]
   for (i in seq_len(niter)) {
      H <- lgH(X, biter)
      gv <- lgGrad(X, y, biter)
      ## solve H d = gv directly instead of forming the inverse; drop() keeps
      ## biter a plain vector
      biter <- biter - drop(solve(H, gv))
      lvv[i + 1] <- mLL(X, y, biter)
      bM[i + 1, ] <- biter
   }
   list(mll = lvv, bM = bM)
}

## run newtF from its default start (all zeros); compare its last row with bhat
temp <- newtF(X, y)
print(temp)
## $mll
##  [1] 0.6931472 0.4263295 0.3775263 0.3691392 0.3687194 0.3687179 0.3687179
##  [8] 0.3687179 0.3687179 0.3687179 0.3687179
## 
## $bM
##            [,1]      [,2]
##  [1,] 0.0000000 0.0000000
##  [2,] 0.6765999 0.7529527
##  [3,] 0.9843015 1.4107285
##  [4,] 1.1960377 1.7928723
##  [5,] 1.2565111 1.9043279
##  [6,] 1.2602107 1.9116221
##  [7,] 1.2602239 1.9116498
##  [8,] 1.2602239 1.9116498
##  [9,] 1.2602239 1.9116498
## [10,] 1.2602239 1.9116498
## [11,] 1.2602239 1.9116498
print(bhat)
##       x1       x2 
## 1.260224 1.911650
## Mini-batch SGD with momentum for minimizing mLL:
##   v <- gam * v - eta * grad(batch);  beta <- beta + v
## Arguments:
##   X, y:   design matrix and response
##   nepoch: number of passes over the data
##   bs:     batch size; batches are consecutive row blocks and the last batch
##           of each epoch absorbs any leftover rows (bs > n gives one
##           full-data batch per epoch)
##   gam:    momentum parameter
##   eta:    learning rate
##   biter:  starting coefficients; if NA, drawn uniformly on (-0.8, 0.8)
## Returns list(mll   = full-data mLL at each iterate (first entry is the start),
##              bM    = matrix of iterates, row 1 = start,
##              epind = epoch that produced each iterate, 0 for the start)
sgdF <- function(X, y, nepoch = 20, bs = 10, gam = .8, eta = 1, biter = NA) {
   ## batch / epoch choices
   n <- length(y)
   p <- ncol(X)
   ## at least one batch per epoch; with floor(n / bs) == 0 the loop below
   ## would otherwise index rows beyond n
   nbatch <- max(1, floor(n / bs))
   niter <- nepoch * nbatch

   ## if no biter, random starting values
   if (is.na(biter[1])) {
      biter <- -.8 + 1.6 * runif(p)
   }
   cat("starting biter: \n")
   print(biter)

   ## per-update momentum / learning-rate schedules (constant here)
   gammav <- rep(gam, niter)
   etav <- rep(eta, niter)

   ## init beta and storage
   bM <- matrix(0.0, niter + 1, length(biter))
   bM[1, ] <- biter
   lvv <- rep(0, niter + 1)
   lvv[1] <- mLL(X, y, biter)

   ## zero initial velocity and counters
   v <- matrix(rep(0, p), ncol = 1)
   inum <- 1
   cv <- rep(0, niter + 1) # epoch of each iterate
   for (j in seq_len(nepoch)) {
      if (j %% 50 == 0) cat("** on epoch: ", j, "\n")
      for (i in seq_len(nbatch)) {
         bbegin <- (i - 1) * bs + 1
         bend <- if (i == nbatch) n else i * bs
         ii <- bbegin:bend
         ## drop = FALSE keeps a one-row batch as a matrix
         gv <- lgGrad(X[ii, , drop = FALSE], y[ii], biter)
         ## schedules have one entry per update: index by the global count
         v <- gammav[inum] * v - etav[inum] * gv
         biter <- biter + v

         lvv[inum + 1] <- mLL(X, y, biter)
         bM[inum + 1, ] <- biter
         cv[inum + 1] <- j
         inum <- inum + 1
      }
   }
   list(mll = lvv, bM = bM, epind = cv)
}
## check: sgdF() with the same settings and start reproduces the momentum loop above
temp <- sgdF(X, y, nepoch = 50, eta = 5, gam = .9, bs = 10, biter = c(1.2, .2))
## starting biter: 
## [1] 1.2 0.2
## ** on epoch:  50
plot(temp$mll, slvv)
abline(a = 0, b = 1, col = "red")

## check: fitted probabilities at the mle vs. at the final sgd iterate
ph <- phat(X, bhat)
psgd <- phat(X, temp$bM[nrow(temp$bM), ])
plot(ph, psgd)
abline(a = 0, b = 1, col = "red")

### try p > 2: ten predictors, only the first two with nonzero coefficients

obeta <- beta
p <- 10
n <- 200
beta <- c(obeta, rep(0, p - length(obeta)))
X <- matrix(rnorm(n * p), ncol = p)

y <- simData(X, beta)
## logit mle
ddf <- data.frame(X, y)
lgm <- glm(y ~ . - 1, data = ddf, family = "binomial")
print(summary(lgm))
## 
## Call:
## glm(formula = y ~ . - 1, family = "binomial", data = ddf)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.1800  -0.6271   0.2870   0.8447   2.0952  
## 
## Coefficients:
##     Estimate Std. Error z value Pr(>|z|)    
## X1   0.97909    0.20867   4.692 2.70e-06 ***
## X2   1.72486    0.27390   6.297 3.03e-10 ***
## X3  -0.09146    0.19437  -0.471    0.638    
## X4   0.15120    0.18245   0.829    0.407    
## X5   0.06151    0.17179   0.358    0.720    
## X6  -0.26169    0.17920  -1.460    0.144    
## X7   0.01987    0.20605   0.096    0.923    
## X8   0.16327    0.19605   0.833    0.405    
## X9  -0.03384    0.17333  -0.195    0.845    
## X10 -0.11602    0.19912  -0.583    0.560    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 277.26  on 200  degrees of freedom
## Residual deviance: 189.04  on 190  degrees of freedom
## AIC: 209.04
## 
## Number of Fisher Scoring iterations: 5
## mle from R
bhat <- coef(lgm)

## check: the gradient returned by lgGrad should be ~0 at the mle
cat("grad a mle: ", lgGrad(X, y, bhat), "\n")
## grad a mle:  -1.322171e-12 -3.801243e-12 -1.297912e-13 -9.644566e-13 3.087439e-13 4.814926e-13 6.114585e-14 -1.013908e-12 1.002948e-12 -6.876095e-13
## newton from the default start; its last iterate should match the glm mle
temp <- newtF(X, y)
cbind(bhat, temp$bM[nrow(temp$bM), ])
##            bhat            
## X1   0.97908842  0.97908842
## X2   1.72485582  1.72485582
## X3  -0.09146346 -0.09146346
## X4   0.15120089  0.15120089
## X5   0.06150654  0.06150654
## X6  -0.26169139 -0.26169139
## X7   0.01987066  0.01987066
## X8   0.16327134  0.16327134
## X9  -0.03384050 -0.03384050
## X10 -0.11602227 -0.11602227
## sgd with momentum from a random start (biter not supplied), many epochs

xlio <- sgdF(X, y, bs = 10, nepoch = 5000, gam = .9, eta = 1)
## starting biter: 
##  [1] -0.2047123  0.4744389  0.4425000 -0.5777517  0.6649311  0.4235054
##  [7] -0.6803333 -0.7427743  0.1625757  0.1367371
## ** on epoch:  50 
## ** on epoch:  100 
## ** on epoch:  150 
## ** on epoch:  200 
## ** on epoch:  250 
## ** on epoch:  300 
## ** on epoch:  350 
## ** on epoch:  400 
## ** on epoch:  450 
## ** on epoch:  500 
## ** on epoch:  550 
## ** on epoch:  600 
## ** on epoch:  650 
## ** on epoch:  700 
## ** on epoch:  750 
## ** on epoch:  800 
## ** on epoch:  850 
## ** on epoch:  900 
## ** on epoch:  950 
## ** on epoch:  1000 
## ** on epoch:  1050 
## ** on epoch:  1100 
## ** on epoch:  1150 
## ** on epoch:  1200 
## ** on epoch:  1250 
## ** on epoch:  1300 
## ** on epoch:  1350 
## ** on epoch:  1400 
## ** on epoch:  1450 
## ** on epoch:  1500 
## ** on epoch:  1550 
## ** on epoch:  1600 
## ** on epoch:  1650 
## ** on epoch:  1700 
## ** on epoch:  1750 
## ** on epoch:  1800 
## ** on epoch:  1850 
## ** on epoch:  1900 
## ** on epoch:  1950 
## ** on epoch:  2000 
## ** on epoch:  2050 
## ** on epoch:  2100 
## ** on epoch:  2150 
## ** on epoch:  2200 
## ** on epoch:  2250 
## ** on epoch:  2300 
## ** on epoch:  2350 
## ** on epoch:  2400 
## ** on epoch:  2450 
## ** on epoch:  2500 
## ** on epoch:  2550 
## ** on epoch:  2600 
## ** on epoch:  2650 
## ** on epoch:  2700 
## ** on epoch:  2750 
## ** on epoch:  2800 
## ** on epoch:  2850 
## ** on epoch:  2900 
## ** on epoch:  2950 
## ** on epoch:  3000 
## ** on epoch:  3050 
## ** on epoch:  3100 
## ** on epoch:  3150 
## ** on epoch:  3200 
## ** on epoch:  3250 
## ** on epoch:  3300 
## ** on epoch:  3350 
## ** on epoch:  3400 
## ** on epoch:  3450 
## ** on epoch:  3500 
## ** on epoch:  3550 
## ** on epoch:  3600 
## ** on epoch:  3650 
## ** on epoch:  3700 
## ** on epoch:  3750 
## ** on epoch:  3800 
## ** on epoch:  3850 
## ** on epoch:  3900 
## ** on epoch:  3950 
## ** on epoch:  4000 
## ** on epoch:  4050 
## ** on epoch:  4100 
## ** on epoch:  4150 
## ** on epoch:  4200 
## ** on epoch:  4250 
## ** on epoch:  4300 
## ** on epoch:  4350 
## ** on epoch:  4400 
## ** on epoch:  4450 
## ** on epoch:  4500 
## ** on epoch:  4550 
## ** on epoch:  4600 
## ** on epoch:  4650 
## ** on epoch:  4700 
## ** on epoch:  4750 
## ** on epoch:  4800 
## ** on epoch:  4850 
## ** on epoch:  4900 
## ** on epoch:  4950 
## ** on epoch:  5000
## number of stored iterates (the start plus one per update)
nb <- nrow(xlio$bM)

plot(xlio$mll)

## check: fitted probabilities at the mle vs. at the final sgd iterate
ph <- phat(X, bhat)
psgd <- phat(X, xlio$bM[nb, ])
plot(ph, psgd)
abline(a = 0, b = 1, col = "red")