Accessing weights in custom loss function

Hello again! I am trying to do some custom regularisation in my training: specifically, a form of Maximum a Posteriori (MAP) estimation in which I place a Horseshoe prior on the weights and biases of the model (see for example: https://proceedings.mlr.press/v5/carvalho09a/carvalho09a.pdf). I have been looking for help on how to extract the weights and use them in the loss function, but I could not find anything. With the help of ChatGPT, I came up with the following after much trial and error. I think it is working, as everything gets updated properly in the model, and the results make sense on simulated data. But I would appreciate some help on this.
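For context, the hierarchical model I am trying to implement (this is my reading of the paper above, so please correct me if I have it wrong) places, on each weight w_j,

w_j | lambda_j, tau ~ N(0, lambda_j^2 * tau^2), with lambda_j ~ C+(0, 1) (half-Cauchy),

and the loss below is then the negative log posterior, with sigma, tau and the lambda_j parameterised on the log scale so they stay positive.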

What I did is:

  1. Create a function that fuses the weights and biases into a single flat tensor that stays attached to the autograd graph, and returns scalar “views” into it, one per weight/bias:
import torch
import torch.nn as nn

def fuse_parameters(model):
    """Flatten all model parameters into a single tensor that stays connected
    to the autograd graph, and return it together with one scalar view per
    individual weight/bias."""
    # Flatten every parameter tensor and concatenate. Unlike copying values
    # into a fresh tensor via .data (which detaches them, so the prior could
    # never push gradients back into the model weights), torch.cat is
    # differentiable with respect to the original parameters.
    params = torch.cat([p.view(-1) for p in model.parameters()])

    # One scalar view per parameter element: the loss below pairs each of
    # these with its own local shrinkage scale lambda_j
    views = list(params)

    return params, views
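To check that the fused tensor behaves as expected, here is a quick sanity check on a toy model (the two-layer net and sizes below are just an illustrative stand-in, not my actual model):

toy = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))

params, views = fuse_parameters(toy)

# 3*2 + 2 + 2*1 + 1 = 11 scalar parameters, one view per weight/bias,
# all still attached to the autograd graph
print(params.numel(), len(views), views[0].requires_grad)  # 11 11 True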

  2. Define a loss function that also takes the model, and put the prior on each of the weights and biases returned by the previous function:
class LinearRegressionHorseshoe(nn.Module):

    def __init__(self, nParams, nT):
        super().__init__()

        # Fixed quantities
        self.nT = nT  # Number of observations

        # Estimated parameters, all on the log scale so that exponentiating
        # them gives strictly positive values
        self.sigma   = nn.Parameter(torch.tensor([0.0]))    # Log std of the regression noise
        self.lambdas = nn.Parameter(torch.zeros(nParams))   # Log local shrinkage scales, one per weight
        self.tau0    = nn.Parameter(torch.tensor([0.0]))    # Log global shrinkage scale

    def forward(self, predictions, targets, nnmodel):

        # Map the log-scale parameters back to positive values
        sigma_exp = torch.exp(self.sigma)
        tau0_exp  = torch.exp(self.tau0)

        # Conditional log-likelihood of the data
        squared_diff = (predictions - targets) ** 2
        llk = - torch.sum(squared_diff) / (2 * sigma_exp ** 2) \
              - self.nT / 2 * torch.log(sigma_exp ** 2)

        # Regularisation for sigma and tau0 (N(0, 1) priors on the log scale)
        p_sigma = - 0.5 * (self.sigma ** 2)
        p_tau0  = - 0.5 * (self.tau0 ** 2)

        # Horseshoe regularisation term: one local scale per weight
        ss_reg = 0.0
        params, views = fuse_parameters(nnmodel)

        for param, sp in zip(views, self.lambdas):

            # Make the local scale positive
            esp = torch.exp(sp)

            # Prior variance of this weight
            vweight = (esp ** 2) * (tau0_exp ** 2)

            # Gaussian log-density of the weight, plus the half-Cauchy
            # log-prior on lambda and the log-Jacobian of the exp transform
            ss_reg += - param.pow(2) / (2 * vweight) - 0.5 * torch.log(vweight) \
                      - torch.log(1 + esp ** 2) + sp

        # Sum everything and flip the sign, since we minimise the
        # negative log posterior
        loss = - (llk + p_sigma + p_tau0 + ss_reg)
        return loss
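And this is roughly how I train with it; the data, network, optimiser and learning rate below are placeholders for illustration, not my real setup:

X, y = torch.randn(100, 3), torch.randn(100, 1)

model = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))
nParams = sum(p.numel() for p in model.parameters())
loss_fn = LinearRegressionHorseshoe(nParams=nParams, nT=len(y))

# Optimise the network weights and the loss module's own parameters
# (sigma, lambdas, tau0) jointly
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()), lr=1e-2
)

for epoch in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(X), y, model)
    loss.backward()
    optimizer.step()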

What do you think?