Loss function for Bayesian Neural Network: backward() issues

Hello PyTorch community,

I’m currently working on a Bayesian Neural Network (BNN) and I’ve encountered an issue with my custom loss function. I have a BnnLoss class that calculates the base loss and adds a complexity cost (KL divergence) for BNNs. Here’s the code for the class:

```python
import torch


class BnnLoss(torch.nn.Module):
    def __init__(self, base_loss: torch.nn.Module, kl_weight: float, model: torch.nn.Module = None):
        super().__init__()  # required before assigning submodules such as base_loss
        self.base_loss = base_loss
        self.kl_weight = kl_weight
        # Stores a reference to `model`; it must be the same object that is being trained
        self.model = model if isinstance(model, BnnFourierNet) else None

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        data_loss = self.base_loss(outputs.squeeze(), targets)

        if self.model is not None:
            # Complexity cost (KL divergence), scaled by kl_weight / batch size
            complexity_cost = self.model.calculate_kl_divergence()
            n_data = targets.shape[0]
            return data_loss + (self.kl_weight / n_data) * complexity_cost
        return data_loss
```
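Written out, the quantity that forward returns when a model is attached is roughly the following. Here β is kl_weight, N is the batch size (targets.shape[0]), and the KL term is whatever calculate_kl_divergence returns; it is assumed here to be the divergence of the weight posterior q(w) from the prior p(w), the usual BNN complexity cost.

```latex
\mathcal{L}(\hat{y}, y) = \mathcal{L}_{\text{base}}(\hat{y}, y) + \frac{\beta}{N}\,\mathrm{KL}\big(q(\mathbf{w}) \,\|\, p(\mathbf{w})\big)
```

Since the second term depends only on the model's parameters, its gradient flows only into the parameters of the model object that computed it.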

When I use this class in my training loop and call total_loss.backward(), the training result is not as expected: the standard deviation of the model’s output becomes very low, as if only base_loss were used without the KL term.

However, if I calculate the loss directly in the training loop without the BnnLoss class, everything works as expected. Here is the training loop, in which I compare both approaches:

```python
import time

import torch


def train_solo_cycle(model, loss, device, conf, trainloader, opt, prof=None):
    for ep in range(conf["epochs"]):
        start_time = time.time()
        for batch in trainloader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            outputs = model(inputs)

            # Loss calculated with the BnnLoss class
            total_loss = loss(outputs.squeeze(), targets)

            # Same loss calculated directly in the training loop, using base_loss
            # (e.g. torch.nn.BCELoss()) plus the KL term. This works as intended.
            total_loss_solo_cycle = total_loss
            if isinstance(model, BnnFourierNet):
                data_loss = loss.base_loss(outputs.squeeze(), targets)
                complexity_cost = model.calculate_kl_divergence()
                n_data = targets.shape[0]
                total_loss_solo_cycle = data_loss + conf["kl_weight"] / n_data * complexity_cost
                # This prints tensor(True): both losses have the same value
                print(torch.eq(total_loss, total_loss_solo_cycle))

            opt.zero_grad()
            # Calling total_loss.backward() (the BnnLoss result) gives the same training
            # result as plain torch.nn.BCELoss() (the output deviation becomes very low).
            # Calling total_loss_solo_cycle.backward() gives the expected result.
            total_loss_solo_cycle.backward()
            opt.step()

    return model
```

What’s strange is that the loss tensors produced by the BnnLoss class and by the training loop are equal. However, training behaves differently depending on whether I call backward() on the tensor returned by BnnLoss or on the one computed directly in the loop.
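Equal loss values do not imply equal gradients: a tensor's value says nothing about which parameters its autograd graph leads back to. A minimal sketch, assuming a plain torch.nn.Linear as a stand-in for the BNN and a hypothetical penalty() function in place of calculate_kl_divergence():

```python
import copy

import torch

torch.manual_seed(0)

base_model = torch.nn.Linear(3, 1)
trained_model = copy.deepcopy(base_model)  # the copy that an optimizer would update


def penalty(model: torch.nn.Module) -> torch.Tensor:
    # Hypothetical stand-in for calculate_kl_divergence(): depends only on parameters
    return sum((p ** 2).sum() for p in model.parameters())


loss_from_copy = penalty(trained_model)
loss_from_base = penalty(base_model)

# Same value, because deepcopy duplicated the parameters exactly
print(torch.equal(loss_from_copy, loss_from_base))  # True

# Backpropagating the penalty computed from base_model...
loss_from_base.backward()

# ...fills gradients on base_model, not on the copy being trained
print(base_model.weight.grad is not None)  # True
print(trained_model.weight.grad is None)   # True
```

If the KL term is computed from a different model object than the one passed to the optimizer, its gradient never reaches the optimized parameters, even though the loss value looks identical.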

I’ve tried using torch.autograd.set_detect_anomaly(True) to identify the source of the issue, but it didn’t provide any additional insights.
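set_detect_anomaly reports errors raised during the backward pass (for example, a backward function producing NaN); a loss that is numerically fine but connected to the wrong parameters passes silently. A check that can catch this asks autograd whether the loss depends on the trained model's parameters at all. A sketch using a hypothetical helper check_loss_reaches_model and a toy penalty term:

```python
import copy

import torch


def check_loss_reaches_model(loss: torch.Tensor, model: torch.nn.Module) -> bool:
    """Hypothetical helper: True if `loss` has a gradient path to at least one parameter of `model`."""
    params = [p for p in model.parameters() if p.requires_grad]
    # allow_unused=True returns None for parameters the loss does not depend on
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return any(g is not None for g in grads)


# Toy demonstration: a stand-in model and a parameter-only penalty term
base_model = torch.nn.Linear(3, 1)
trained_model = copy.deepcopy(base_model)

penalty_from_base = sum((p ** 2).sum() for p in base_model.parameters())
penalty_from_trained = sum((p ** 2).sum() for p in trained_model.parameters())

print(check_loss_reaches_model(penalty_from_trained, trained_model))  # True
print(check_loss_reaches_model(penalty_from_base, trained_model))     # False
```

Called on the KL term and the model whose parameters the optimizer holds, such a check would return False whenever the loss is wired to a different model object.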

I would appreciate any insights or suggestions on what might be causing this issue and how to resolve it.

Thank you in advance for your help!

Did you verify that your custom BnnLoss.forward is actually taking the if self.model is not None branch?

Yes, I checked that in debug mode. Also, changing the forward function to the following implementation did not change the final result:

```python
def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    complexity_cost = self.model.calculate_kl_divergence()
    n_data = targets.shape[0]
    return self.base_loss(outputs.squeeze(), targets) + (self.kl_weight / n_data) * complexity_cost
```

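OK, I found the error. I initialized BnnLoss with the freshly created, untrained model, but elsewhere in the code I run several training sessions, each on copy.deepcopy(base_model). So BnnLoss never saw the model that was actually being trained: it kept computing the KL term from the original, unprocessed model, whose parameters are not the ones the optimizer updates, so only the base-loss gradient reached the trained copy.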
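One way to avoid the stale reference is to build the loss from the same model object the optimizer updates, inside each training session. A minimal sketch under these assumptions: ToyBnn, its calculate_kl_divergence and run_session are hypothetical stand-ins for BnnFourierNet and the session code, and the KL is the closed-form divergence of a diagonal Gaussian weight posterior from a standard normal prior.

```python
import copy

import torch
import torch.nn.functional as F


class ToyBnn(torch.nn.Module):
    """Hypothetical stand-in for BnnFourierNet: one Bayesian linear layer."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.w_mu = torch.nn.Parameter(torch.zeros(out_features, in_features))
        self.w_rho = torch.nn.Parameter(torch.full((out_features, in_features), -3.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sigma = F.softplus(self.w_rho)
        # Reparameterization trick: sample weights as mu + sigma * eps
        w = self.w_mu + sigma * torch.randn_like(sigma)
        return torch.sigmoid(x @ w.t())

    def calculate_kl_divergence(self) -> torch.Tensor:
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all weights
        sigma = F.softplus(self.w_rho)
        return (0.5 * (sigma ** 2 + self.w_mu ** 2 - 1.0) - torch.log(sigma)).sum()


class BnnLoss(torch.nn.Module):
    def __init__(self, base_loss: torch.nn.Module, kl_weight: float, model: torch.nn.Module):
        super().__init__()
        self.base_loss = base_loss
        self.kl_weight = kl_weight
        self.model = model  # must be the very object the optimizer updates

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        data_loss = self.base_loss(outputs.squeeze(), targets)
        n_data = targets.shape[0]
        return data_loss + (self.kl_weight / n_data) * self.model.calculate_kl_divergence()


def run_session(base_model, inputs, targets, kl_weight=1.0, steps=5):
    model = copy.deepcopy(base_model)  # fresh copy per session, as in the original code
    # Build the loss from this copy, not from base_model
    loss_fn = BnnLoss(torch.nn.BCELoss(), kl_weight, model)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(steps):
        opt.zero_grad()
        total_loss = loss_fn(model(inputs), targets)
        total_loss.backward()
        opt.step()
    return model, loss_fn


torch.manual_seed(0)
base_model = ToyBnn(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,)).float()

for session in range(2):
    model, loss_fn = run_session(base_model, inputs, targets)
    print(loss_fn.model is model)  # True: the KL term comes from the trained copy
```

Another option, not tested here, is to pass the model to forward at call time (for example loss(outputs, targets, model)), so the loss object cannot hold on to an outdated model between sessions.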