I have the following custom loss function that operates over the entire batch to produce the loss. I am getting zero gradients, so I wanted to know why autograd doesn’t work for the following code.
import torch
from torch import nn
from torch.nn import functional as F

class PairwiseLoss(nn.Module):
    def __init__(self):
        super(PairwiseLoss, self).__init__()

    def forward(self, outputs, memScores):
        # All pairwise differences o_i - o_j, squashed into (0, 1)
        outputsi = outputs.expand([outputs.size(0), outputs.size(0)])
        outputsj = outputsi.clone()
        outputsj.transpose_(0, 1)
        outputsij = F.sigmoid(outputsi - outputsj)
        # Pairwise targets: 1 if score_i > score_j, 0.5 on ties, 0 otherwise
        memScoresi = memScores.expand([memScores.size(0), memScores.size(0)])
        memScoresj = memScoresi.clone()
        memScoresj.transpose_(0, 1)
        memScoresij = torch.gt(memScoresi, memScoresj) + torch.ge(memScoresi, memScoresj)
        memScoresij = memScoresij.float()
        memScoresij = memScoresij / 2
        # Flatten both N x N tensors and apply binary cross-entropy over all pairs
        Values = outputsij.view(outputsij.size(0) * outputsij.size(1))
        labels = memScoresij.view(memScoresij.size(0) * memScoresij.size(1))
        BCE = nn.BCELoss()
        loss = BCE(Values, labels)
        return loss
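For reference, the same pairwise construction can be written more compactly with broadcasting instead of expand/clone/transpose_. This is only a sketch, assuming outputs and memScores are (N, 1) column vectors as above:

import torch
from torch.nn import functional as F

def pairwise_loss(outputs, memScores):
    # outputs is (N, 1) and outputs.t() is (1, N), so the difference broadcasts to (N, N)
    probs = torch.sigmoid(outputs - outputs.t())  # sigmoid(o_i - o_j) for every pair
    # 1 where score_i > score_j, 0.5 on ties, 0 otherwise -- same targets as above
    labels = (memScores > memScores.t()).float() + 0.5 * (memScores == memScores.t()).float()
    # binary_cross_entropy accepts the (N, N) tensors directly, no flattening needed
    return F.binary_cross_entropy(probs, labels)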
albanD (Alban D) replied on September 24, 2018, 9:51am:
Hi,
I can’t reproduce your problem. The following code gives some non-zero gradients for the outputs vector:
import torch
from torch import nn
from torch.nn import functional as F

# This is just your class copy pasted, not modified
class PairwiseLoss(nn.Module):
    def __init__(self):
        super(PairwiseLoss, self).__init__()

    def forward(self, outputs, memScores):
        outputsi = outputs.expand([outputs.size(0), outputs.size(0)])
        outputsj = outputsi.clone()
        outputsj.transpose_(0, 1)
        outputsij = F.sigmoid(outputsi - outputsj)
        memScoresi = memScores.expand([memScores.size(0), memScores.size(0)])
        memScoresj = memScoresi.clone()
        memScoresj.transpose_(0, 1)
        memScoresij = torch.gt(memScoresi, memScoresj) + torch.ge(memScoresi, memScoresj)
        memScoresij = memScoresij.float()
        memScoresij = memScoresij / 2
        Values = outputsij.view(outputsij.size(0) * outputsij.size(1))
        labels = memScoresij.view(memScoresij.size(0) * memScoresij.size(1))
        BCE = nn.BCELoss()
        loss = BCE(Values, labels)
        return loss

outputs = torch.rand(10, 1, requires_grad=True)
memScores = torch.rand(10, 1, requires_grad=True)

loss = PairwiseLoss()(outputs, memScores)
print(loss)
loss.backward()
print(outputs.grad)
print(memScores.grad)
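Note that memScores.grad should come back as None here: torch.gt and torch.ge are non-differentiable, so the targets are detached from the graph and no gradient flows back to memScores. A minimal illustration of that effect, independent of the loss above:

import torch

m = torch.rand(5, requires_grad=True)
labels = (m > 0.5).float()   # the comparison produces a tensor detached from the autograd graph
print(labels.requires_grad)  # False: no gradient can ever reach m through labels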
Thanks a lot. I figure the problem must be somewhere else, then.