Hi all, I’m building a network that predicts objects in output grids, so the output shape is [class_size, grid_width, grid_height]. Right now the loss function looks like this: output is [batch_size, class_size, grid_width, grid_height] and target is [1, grid_width, grid_height], containing the class label if there is an object in a grid cell and 0 otherwise.
for b, w, h in torch.nonzero(target.data):
    loss += F.cross_entropy(output[b, :, w, h].contiguous().view(1, -1), target[b, w, h])
As you can imagine, this loss function is very slow and is the bottleneck, so I’d like to change it if possible.
The reason the target is a tensor is that I use a DataLoader, and it does not accept arbitrary objects — for example, a list storing the grid and corresponding class label for each image, like [dict(grid=(x, y), cls=c), dict(grid=(x, y), cls=c), ...].
Does anyone have a good idea for improving the performance? Thanks in advance.
Use loss = ce_loss(output.transpose(0,2,3,1).contiguous().view(-1, class_size), target.view(-1))
The only thing you would need to take care of is the scaling of the loss. I think that by default it would now average over target.numel() instead of (target != 0).sum(), so you would need to scale by the quotient.
Best regards
Thomas
Edit: The transpose above is incorrect, I had numpy’s transpose in mind. See below for what I actually wanted.
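For reference, here is a minimal sketch of what the vectorized call might look like in PyTorch, where numpy’s transpose(0, 2, 3, 1) corresponds to permute(0, 2, 3, 1). The sizes and data here are toy values chosen just for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes matching the shapes in the post: [batch, class, W, H]
batch_size, class_size, grid_w, grid_h = 2, 4, 3, 3
output = torch.randn(batch_size, class_size, grid_w, grid_h)
target = torch.zeros(batch_size, grid_w, grid_h, dtype=torch.long)
target[0, 1, 2] = 3  # one object of class 3

# Move the class dimension last, then flatten to [N, class_size]
flat_scores = output.permute(0, 2, 3, 1).contiguous().view(-1, class_size)
flat_target = target.view(-1)
loss = F.cross_entropy(flat_scores, flat_target)
```

Note that recent PyTorch versions also accept multi-dimensional inputs directly — F.cross_entropy(output, target) with output of shape [N, C, W, H] and target of shape [N, W, H] computes the same mean without any reshaping.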
I’m afraid my explanation was not clear enough and confused you. Suppose grid_size=3 and there is only one object, with label 3, in grid [1, 2]; then the target is
[[0,0,0],
[0,0,0],
[0,3,0]]
The cross entropy loss is ce(output[:, 1, 2], [3]). I used nonzero to get the indices of all grid cells containing an object, but that was a bit slow.
Indeed, you confused me.
My impression was that you had data similar to
scores = Variable(torch.randn(2,4,3,3))
targets = torch.zeros(1,3,3).long()
targets[:,2,1] = 3
targets = Variable(targets)
targets = targets.expand(scores.size(0),*targets.size()[1:]) # if you only have one target per batch
Thank you Thomas, and I have a question about your solution: I don’t know why you put a weight in CrossEntropyLoss, though it may be trivial. -> Finally I understand: 0 is the background class, so it needs to be ignored. Thank you again, it’s solved.
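Putting the pieces together, here is a minimal sketch (with hypothetical toy sizes and random data) checking that the vectorized loss with ignore_index=0 — so that background cells are skipped, as discussed above — matches the original per-cell loop:

```python
import torch
import torch.nn.functional as F

batch_size, class_size, grid_w, grid_h = 2, 4, 3, 3
output = torch.randn(batch_size, class_size, grid_w, grid_h)
target = torch.zeros(batch_size, grid_w, grid_h, dtype=torch.long)
target[0, 1, 2] = 3  # object of class 3
target[1, 0, 0] = 2  # object of class 2

# Original approach: loop over the nonzero (object-containing) cells,
# then average over the number of object cells.
loop_loss = 0.0
n_obj = 0
for b, w, h in torch.nonzero(target):
    loop_loss = loop_loss + F.cross_entropy(
        output[b, :, w, h].view(1, -1), target[b, w, h].view(1))
    n_obj += 1
loop_loss = loop_loss / n_obj

# Vectorized: flatten scores to [N, class_size], flatten targets to [N],
# and let ignore_index=0 drop the background cells from the average.
vec_loss = F.cross_entropy(
    output.permute(0, 2, 3, 1).contiguous().view(-1, class_size),
    target.view(-1),
    ignore_index=0)
```

With ignore_index=0 and the default mean reduction, the loss is averaged only over the non-background cells, so no manual rescaling by the quotient is needed.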