Hello PyTorch community,
I’m working on a Bayesian Neural Network (BNN) and have run into an issue with my custom loss function. My BnnLoss class computes the base loss and adds a complexity cost (the KL divergence), scaled by kl_weight divided by the batch size. Here’s the class:
import torch


class BnnLoss(torch.nn.Module):
    def __init__(self, base_loss: torch.nn.Module, kl_weight: float, model: torch.nn.Module = None):
        super().__init__()
        self.base_loss = base_loss
        self.kl_weight = kl_weight
        # BnnFourierNet is my own BNN model class
        self.model = model if isinstance(model, BnnFourierNet) else None

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        data_loss = self.base_loss(outputs.squeeze(), targets)
        if self.model is not None:
            complexity_cost = self.model.calculate_kl_divergence()
            n_data = targets.shape[0]
            return data_loss + (self.kl_weight / n_data) * complexity_cost
        else:
            return data_loss
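As written, __init__ sets self.model to None whenever the object passed in is not a BnnFourierNet instance (a wrapper such as nn.DataParallel, for instance), and the loss then silently reduces to base_loss. A minimal sketch of a variant that fails loudly instead, assuming only that the BNN exposes calculate_kl_divergence(); StrictBnnLoss and ToyBnn are hypothetical names for illustration, and ToyBnn's "KL" is just a sum of squared weights:

```python
import torch
import torch.nn as nn


class StrictBnnLoss(nn.Module):
    """Hypothetical variant of BnnLoss that refuses to silently drop the KL term."""

    def __init__(self, base_loss: nn.Module, kl_weight: float, model: nn.Module = None):
        super().__init__()
        # model=None explicitly means "no KL term"; anything else must provide the KL method
        if model is not None and not hasattr(model, "calculate_kl_divergence"):
            raise ValueError(
                f"{type(model).__name__} has no calculate_kl_divergence(); "
                "pass model=None to train with base_loss only"
            )
        self.base_loss = base_loss
        self.kl_weight = kl_weight
        self.model = model

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        data_loss = self.base_loss(outputs.squeeze(), targets)
        if self.model is None:
            return data_loss
        n_data = targets.shape[0]
        return data_loss + (self.kl_weight / n_data) * self.model.calculate_kl_divergence()


class ToyBnn(nn.Module):
    """Hypothetical stand-in for BnnFourierNet; the 'KL' here is a sum of squares."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

    def calculate_kl_divergence(self):
        return sum((p ** 2).sum() for p in self.parameters())
```

Duck typing on calculate_kl_divergence() is a design choice here, not a requirement: it accepts any module that provides the method and raises at construction time rather than letting a mis-wired model quietly train without the complexity cost.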
When I use this class in my training loop and call total_loss.backward(), training does not behave as expected: the standard deviation of the model’s output becomes very low, as if I were training with base_loss alone, without the KL term.
However, if I calculate the loss directly in the training loop without using the BnnLoss class, everything works as expected. Here is my training loop, where I compare both approaches:
import time

import torch


def train_solo_cycle(model, loss, device, conf, trainloader, opt, prof=None):
    for ep in range(conf["epochs"]):
        start_time = time.time()
        for batch in trainloader:
            opt.zero_grad()
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            outputs = model.forward(inputs)
            # This loss is computed directly in the training loop, using
            # base_loss (e.g. torch.nn.BCELoss()) directly. Works as intended.
            if isinstance(model, BnnFourierNet):
                data_loss = loss.base_loss(outputs.squeeze(), targets)
                complexity_cost = model.calculate_kl_divergence()
                n_data = targets.shape[0]
                total_loss_solo_cycle = data_loss + conf["kl_weight"] / n_data * complexity_cost
                # Here, I compute the loss with the BnnLoss class
                total_loss = loss(outputs.squeeze(), targets)
                # This evaluates to True
                torch.eq(total_loss, total_loss_solo_cycle)
            # If I call backward() on the loss from the BnnLoss class, training behaves as if
            # I used torch.nn.BCELoss() without any modification
            # (the standard deviation of the model's output becomes very low)
            total_loss.backward()
            # If I call backward() on the variable computed directly in the training loop,
            # training behaves as expected
            # total_loss_solo_cycle.backward()
            opt.step()
    return model
What’s strange is that the loss tensor returned by BnnLoss and the one computed directly in the training loop are equal. Yet backward() behaves differently depending on which of the two I call it on.
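Equal loss values do not by themselves imply equal gradients, because autograd follows the graph that produced each tensor, not its value. A minimal sketch of one way this can happen, assuming (hypothetically) that the model held inside the loss is a copy rather than the very object being trained. The values match, but the KL gradient lands on the copy's parameters, so the trained model only receives base-loss gradients. ToyBnn, total_loss, and stale_copy are hypothetical names, and the "KL" is a sum of squared weights:

```python
import copy

import torch
import torch.nn as nn


class ToyBnn(nn.Module):
    """Hypothetical stand-in for BnnFourierNet; the 'KL' is a sum of squared weights."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

    def calculate_kl_divergence(self):
        return sum((p ** 2).sum() for p in self.parameters())


def total_loss(outputs, targets, kl_model, kl_weight=1.0):
    # Data term plus KL term taken from whichever module is passed as kl_model
    data_loss = nn.functional.mse_loss(outputs.squeeze(), targets)
    return data_loss + (kl_weight / targets.shape[0]) * kl_model.calculate_kl_divergence()


torch.manual_seed(0)
model = ToyBnn()
stale_copy = copy.deepcopy(model)  # same values, but separate parameter tensors
x, y = torch.randn(8, 3), torch.rand(8)

out = model(x)
loss_same = total_loss(out, y, kl_model=model)
loss_copy = total_loss(out, y, kl_model=stale_copy)
print(torch.equal(loss_same, loss_copy))  # identical values

# Gradients w.r.t. the parameters of the model actually being trained.
# retain_graph=True because both losses share the graph of `out`.
g_same = torch.autograd.grad(loss_same, list(model.parameters()), retain_graph=True)
g_copy = torch.autograd.grad(loss_copy, list(model.parameters()))
# With the copy, the KL term contributes nothing to these gradients
print([torch.allclose(a, b) for a, b in zip(g_same, g_copy)])
```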
I’ve tried torch.autograd.set_detect_anomaly(True) to locate the source of the issue, but it didn’t provide any additional insight.
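Anomaly detection reports errors and NaN values raised during backward; it does not report a gradient path that is simply missing. A more direct check is to compare, parameter by parameter, the gradients that each loss sends to the trained model. A minimal sketch, where compare_loss_grads is a hypothetical helper name:

```python
import torch
import torch.nn as nn


def compare_loss_grads(loss_a, loss_b, model):
    """Return {param_name: max |grad_a - grad_b|} over the trainable parameters of `model`.

    Hypothetical debugging helper. retain_graph=True keeps both graphs alive so a
    normal backward() can still be called afterwards; allow_unused=True returns None
    for a parameter that a loss does not depend on, which is treated as a zero gradient.
    """
    names, params = zip(*[(n, p) for n, p in model.named_parameters() if p.requires_grad])
    grads_a = torch.autograd.grad(loss_a, params, retain_graph=True, allow_unused=True)
    grads_b = torch.autograd.grad(loss_b, params, retain_graph=True, allow_unused=True)
    diffs = {}
    for name, p, ga, gb in zip(names, params, grads_a, grads_b):
        ga = torch.zeros_like(p) if ga is None else ga
        gb = torch.zeros_like(p) if gb is None else gb
        diffs[name] = (ga - gb).abs().max().item()
    return diffs
```

In the training loop above, calling compare_loss_grads(total_loss, total_loss_solo_cycle, model) right before total_loss.backward() would show whether the KL term reaches the model’s parameters through BnnLoss; a nonzero entry means the two losses push the weights differently despite having the same value.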
I would appreciate any insights or suggestions on what might be causing this issue and how to resolve it.
Thank you in advance for your help!