SGHMC - does my code do what I think it does?

Greetings,

I am trying to implement the Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) scheme to sample from the Bayesian posterior over neural network weights.

I think I have implemented everything correctly, but I am a bit worried that PyTorch does things “under the hood” that I did not intend.

The update rule of the scheme is given by

$$p_i \leftarrow (1 - \alpha)\, p_i - \eta\, \nabla U(\theta_i) + \sqrt{2 \eta \alpha}\, R_i,$$

$$\theta_i \leftarrow \theta_i + p_i,$$

where $R_i$ is a standard normal random variable, $\eta$ is the learning rate, and $\alpha$ is the friction coefficient.
The $\theta_i$ are the network parameters, and the $p_i$ the corresponding momenta (in the code, I store them in a buf attribute on each parameter).
The $\nabla U$ is the stochastic gradient of the loss plus weight decay.
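
To make the mapping from these equations to the in-place tensor operations explicit, here is a minimal single-tensor sketch of one update step (lr plays the role of $\eta$ and alpha of $\alpha$; the gradient is just a stand-in, since in the real code it comes from loss.backward()):

import torch

lr, alpha = 1e-4, 0.01
theta = torch.zeros(5)  # parameters theta
buf = torch.randn(5)    # momentum p, initialized from N(0, I)

grad_U = theta.clone()  # stand-in for the stochastic gradient of U
eps = torch.randn(5)    # R ~ N(0, I)
buf.mul_(1 - alpha)
buf.add_(-lr * grad_U + (2 * lr * alpha) ** 0.5 * eps)
theta.add_(buf)         # theta <- theta + p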

My code assumes a given network model with an evaluate function that computes the loss and accuracy over a given data loader; a sketch of the interface I have in mind follows.
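
For reference, here is a minimal sketch of the kind of model I mean (this toy Net is purely illustrative; only the evaluate interface matters, and this version assumes a classifier that returns log-probabilities, i.e. NLLLoss-style outputs):

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):  # hypothetical toy model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        return F.log_softmax(self.fc(x.view(x.size(0), -1)), dim=1)

    def evaluate(self, loader):
        # returns (mean loss, accuracy) over the given data loader
        self.eval()
        total_loss, correct, n = 0.0, 0, 0
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = self(data)
                total_loss += F.nll_loss(output, target, reduction="sum").item()
                correct += (output.argmax(dim=1) == target).sum().item()
                n += target.size(0)
        return total_loss / n, correct / n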

Does this class look correct?

import torch
import torch.nn as nn
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use GPU if available

class SGHMC(nn.Module):
    
    def __init__(self, model, train_loader, test_loader, criterion, lr, weight_decay, alpha, epochs):
        super(SGHMC, self).__init__()
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.lr = lr
        self.weight_decay = weight_decay
        self.alpha = alpha
        self.epochs = epochs
        
        
    def train(self):
        
        loss_train = np.zeros(self.epochs+1)                 
        accu_train = np.zeros(self.epochs+1)
        loss_test = np.zeros(self.epochs+1)
        accu_test = np.zeros(self.epochs+1)
        
        (loss_train[0], accu_train[0]) = self.model.evaluate(self.train_loader)
        (loss_test[0], accu_test[0]) = self.model.evaluate(self.test_loader)        
        
        datasize = len(self.train_loader.dataset)
        
        squeeze = isinstance(self.criterion, nn.BCELoss)   # squeeze network output for BCELoss
                                                           # (not required for NLLLoss)
        # initialize momenta p ~ N(0, I), stored as a buf attribute on each parameter
        for p in self.model.parameters():
            p.buf = torch.randn_like(p)

        # train routine
        for epoch in range(1, self.epochs+1):
            self.model.train()
            
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(device), target.to(device)
                self.model.zero_grad()
                output = self.model(data) 
                if squeeze: output=output.squeeze()
                loss = self.criterion(output, target) * datasize  # rescale mean batch loss to dataset size (stochastic estimate of U)
                loss.backward()
                self.update_params()
        
            (loss_train[epoch], accu_train[epoch]) = self.model.evaluate(self.train_loader)
            (loss_test[epoch], accu_test[epoch]) = self.model.evaluate(self.test_loader)
          
        return (loss_train, loss_test, accu_train, accu_test)


    def update_params(self):

        for p in self.model.parameters():
            if p.grad is None:
                continue

            p.grad.data.add_(p.data, alpha=self.weight_decay)  # add weight decay to the gradient, giving grad(U)

            # momentum update: p <- (1 - alpha) * p - lr * grad(U) + sqrt(2 * lr * alpha) * R
            eps = torch.randn_like(p)
            p.buf.mul_(1 - self.alpha)
            p.buf.add_(-self.lr * p.grad.data + (2.0 * self.lr * self.alpha) ** 0.5 * eps)

            # parameter update: theta <- theta + p
            p.data.add_(p.buf)
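
For completeness, this is roughly how I call it (the hyperparameter values are only examples, and train_loader / test_loader are assumed to be ordinary DataLoaders over a classification dataset):

model = Net().to(device)
sampler = SGHMC(model, train_loader, test_loader,
                criterion=nn.NLLLoss(),
                lr=1e-6, weight_decay=5e-4, alpha=0.01, epochs=100)
loss_train, loss_test, accu_train, accu_test = sampler.train()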