Greetings,
I am trying to implement the Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) scheme to sample from the Bayesian posterior of a neural network.
I think I implemented everything correctly, but I am a bit worried that PyTorch does stuff “under the hood” that I did not intend.
The update rule of the scheme is given by

p_{i+1} = (1 - \alpha) p_i - \eta \nabla U(\theta_i) + \sqrt{2 \alpha \eta} R_i
\theta_{i+1} = \theta_i + p_{i+1}

where R_i is a standard normal random variable.
The \theta_i are the network parameters, and the p_i are the corresponding momenta (in the code, I store them in a `buf` attribute on each parameter). The \nabla U is the stochastic gradient of the loss plus weight decay, \eta is `lr` in the code, and \alpha is `alpha`.
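In code form, one step of this rule for a single tensor would look roughly like the following (a minimal sketch; the values of `lr` and `alpha` are illustrative, and `grad_U` stands in for the stochastic gradient \nabla U(\theta_i)):

```python
import torch

lr, alpha = 1e-5, 0.01    # eta and alpha from the update rule (illustrative values)
theta = torch.randn(10)   # parameters theta_i
p = torch.randn(10)       # momenta p_i
grad_U = torch.randn(10)  # stand-in for the stochastic gradient nabla U(theta_i)

# p_{i+1} = (1 - alpha) * p_i - eta * grad_U + sqrt(2 * alpha * eta) * R_i
p = (1 - alpha) * p - lr * grad_U + (2 * alpha * lr) ** 0.5 * torch.randn(10)
# theta_{i+1} = theta_i + p_{i+1}
theta = theta + p
```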
My code assumes a given network model with an `evaluate` function that computes the loss and accuracy for a given data loader.
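For reference, the `evaluate` interface I rely on looks roughly like this (a simplified sketch, not my exact method; it assumes the model stores its own `criterion` attribute and that the task is classification):

```python
def evaluate(self, loader):
    """Return (mean loss, accuracy) over the given data loader."""
    self.eval()
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = self(data)
            # assumes the model has a `criterion` attribute that averages over the batch
            total_loss += self.criterion(output, target).item() * len(data)
            correct += (output.argmax(dim=1) == target).sum().item()
            n += len(data)
    return total_loss / n, correct / n
```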
Does this class look correct?
```python
import torch
import torch.nn as nn
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use GPU if available

class SGHMC(nn.Module):
    def __init__(self, model, train_loader, test_loader, criterion, lr, weight_decay, alpha, epochs):
        super().__init__()
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.lr = lr
        self.weight_decay = weight_decay
        self.alpha = alpha
        self.epochs = epochs

    def train(self):
        loss_train = np.zeros(self.epochs + 1)
        accu_train = np.zeros(self.epochs + 1)
        loss_test = np.zeros(self.epochs + 1)
        accu_test = np.zeros(self.epochs + 1)
        (loss_train[0], accu_train[0]) = self.model.evaluate(self.train_loader)
        (loss_test[0], accu_test[0]) = self.model.evaluate(self.test_loader)
        datasize = len(self.train_loader.dataset)
        # squeeze the network output for BCELoss (not required for NLLLoss)
        squeeze = isinstance(self.criterion, nn.BCELoss)
        # initialize momenta
        for p in self.model.parameters():
            p.buf = torch.randn_like(p)
        # train routine
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(device), target.to(device)
                self.model.zero_grad()
                output = self.model(data)
                if squeeze:
                    output = output.squeeze()
                # rescale the batch-averaged loss so its gradient estimates the full-dataset potential U
                loss = self.criterion(output, target) * datasize
                loss.backward()
                self.update_params()
            (loss_train[epoch], accu_train[epoch]) = self.model.evaluate(self.train_loader)
            (loss_test[epoch], accu_test[epoch]) = self.model.evaluate(self.test_loader)
        return (loss_train, loss_test, accu_train, accu_test)

    def update_params(self):
        for p in self.model.parameters():
            if p.grad is None:  # skip parameters that did not receive a gradient
                continue
            p.grad.data.add_(p.data, alpha=self.weight_decay)  # add weight decay to the gradient
            # momentum update: p <- (1 - alpha) * p - lr * grad + sqrt(2 * lr * alpha) * R
            eps = torch.randn_like(p)
            p.buf.mul_(1 - self.alpha)
            p.buf.add_(-self.lr * p.grad.data + (2.0 * self.lr * self.alpha) ** 0.5 * eps)
            # parameter update: theta <- theta + p
            p.data.add_(p.buf)
```
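And this is roughly how I use the class (a hypothetical setup; `MyNet`, `train_loader`, and `test_loader` are placeholders for the actual model and data):

```python
# hypothetical setup: MyNet and the loaders are placeholders
model = MyNet().to(device)  # an nn.Module exposing evaluate(loader) -> (loss, accuracy)
sampler = SGHMC(model, train_loader, test_loader,
                criterion=nn.NLLLoss(), lr=1e-5,
                weight_decay=5e-4, alpha=0.01, epochs=100)
loss_train, loss_test, accu_train, accu_test = sampler.train()
```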