I am trying to implement a custom loss function for a variant of ViT whose output is a prediction for each patch of the original image.
The input has shape [BxCxHxW] and the label tensor has shape [BxNxRHxRW], where N is the number of classes and RH and RW are the height and width of the patch grid (the number of patches along each image axis after reshaping). This is easier to follow with an example:
The input shape is [5x1x512x512] with 4 classes. Splitting each image into 32x32 patches gives [5x256x1024], i.e. 256 patches of 1024 values each (ignoring the class token for now). The transformer output is reshaped into [5x4x16x16], which is also the final output of the model.
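To make the shape bookkeeping concrete, here is a minimal sketch of the pipeline from the example, assuming non-overlapping patches extracted with `F.unfold`. The transformer itself is replaced by a hypothetical per-patch `nn.Linear` head, since only the shapes matter here; this is not the project's actual model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W = 5, 1, 512, 512
patch, num_classes = 32, 4
grid = H // patch  # 16 patches along each axis

images = torch.randn(B, C, H, W)

# Non-overlapping 32x32 patches: unfold returns (B, C*patch*patch, num_patches)
patches = F.unfold(images, kernel_size=patch, stride=patch)
patches = patches.transpose(1, 2)  # (5, 256, 1024): one row per patch

# Hypothetical stand-in for the transformer: a linear head applied to each patch
head = nn.Linear(C * patch * patch, num_classes)
per_patch = head(patches)  # (5, 256, 4): class scores per patch

# Move classes to dim 1, then fold the 256 patches back into a 16x16 grid
out = per_patch.transpose(1, 2).reshape(B, num_classes, grid, grid)
print(patches.shape, out.shape)  # (5, 256, 1024) and (5, 4, 16, 16)
```

`F.unfold` orders the patches row by row over the grid, so reshaping the 256 patch positions into 16x16 puts each prediction back at its patch's location.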
The label also has shape [5x4x16x16]; along the second dimension, each patch has a probability vector of size 4 giving the probability of each class.
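As an illustration of that label layout, the sketch below builds random soft labels with the class axis on dim 1 and checks that every patch's 4-vector is a valid probability distribution. The random values and the one-hot variant are hypothetical examples, not the project's real labels.

```python
import torch
import torch.nn.functional as F

B, num_classes, grid = 5, 4, 16

# Hypothetical soft labels: softmax over the class axis (dim=1) makes each
# patch's 4-vector non-negative and summing to 1
label = torch.softmax(torch.randn(B, num_classes, grid, grid), dim=1)
per_patch_sum = label.sum(dim=1)  # shape (5, 16, 16), all ones

# Hard per-patch class indices can be turned into the same layout via one-hot;
# F.one_hot appends the class axis last, so permute moves it to dim=1
hard = torch.randint(0, num_classes, (B, grid, grid))
one_hot = F.one_hot(hard, num_classes).permute(0, 3, 1, 2).float()

print(label.shape, one_hot.shape)
print(torch.allclose(per_patch_sum, torch.ones(B, grid, grid)))
```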
I am trying to compute the loss between the output and the label. Using nn.CrossEntropyLoss did not work for me because of the extra dimension, so I created (borrowed from a similar project) a custom loss function:
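import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_patches, num_classes = 5, 256, 4  # values from the example above

class PatchCrossEntropy(nn.Module):
    def __init__(self):
        super(PatchCrossEntropy, self).__init__()

    def forward(self, x, target):
        # soft-label cross-entropy: softmax and sum are both taken over dim=1
        loss = torch.sum(-target * F.log_softmax(x, dim=1), dim=1)
        return loss.mean()

my_criterion = PatchCrossEntropy()

# test data of shape (batch_size, num_patches, num_classes), normalized along the class axis (axis 2)
x = np.random.RandomState(42).normal(size=(batch_size, num_patches, num_classes))
x = x / x.sum(axis=2, keepdims=True)
y = np.random.RandomState(41).normal(size=(batch_size, num_patches, num_classes))
y = y / y.sum(axis=2, keepdims=True)
label = torch.as_tensor(x)
target = torch.as_tensor(y)
loss1 = my_criterion(label, label)
print(loss1)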
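For comparison, here is a sketch of the same loss applied to tensors laid out like the model output, [B, C, H, W] with classes on dim 1, assuming the model returns raw logits (no final softmax). It also checks the result against the built-in `nn.CrossEntropyLoss`, which accepts (N, C, d1, d2) inputs and, in PyTorch 1.10 and later, class-probability targets of the same shape; the PyTorch version is an assumption about the project's environment.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchCrossEntropy(nn.Module):
    """Soft-label cross-entropy; expects raw logits with classes on dim=1."""
    def forward(self, logits, target):
        # per-patch loss: -sum_c target_c * log_softmax(logits)_c
        loss = torch.sum(-target * F.log_softmax(logits, dim=1), dim=1)
        return loss.mean()  # average over batch and all 16x16 patch positions

torch.manual_seed(0)
logits = torch.randn(5, 4, 16, 16)                        # model output before softmax
target = torch.softmax(torch.randn(5, 4, 16, 16), dim=1)  # soft labels, classes on dim=1

custom = PatchCrossEntropy()(logits, target)
builtin = nn.CrossEntropyLoss()(logits, target)  # probability targets need PyTorch >= 1.10
print(custom.item(), builtin.item())

# Tensors shaped (batch, patches, classes), like the random test data in the
# question, need the class axis moved to dim=1 before calling either loss
x_bpc = torch.randn(5, 256, 4)
x_bcp = x_bpc.permute(0, 2, 1)  # (5, 4, 256)
```

With the default `reduction="mean"` and probability targets, the built-in loss averages over the batch and all spatial positions, which is the same reduction as `loss.mean()` in the custom class.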
I expect the output to be 0, but it is not. I suspect the problem is the dimension I am summing over. I also don't know whether I should be using the softmax at all: the last layer of the model is already a softmax layer, so I think an extra one inside the loss function is not needed.
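Both points can be checked numerically. The sketch below, assuming valid soft labels with classes on dim 1, shows that the cross-entropy of a distribution with itself equals its entropy (positive unless every patch is one-hot), while the KL divergence is the quantity that is zero for identical inputs. It also shows that applying `log_softmax` to values that are already probabilities is not the same as taking their log.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = torch.softmax(torch.randn(5, 4, 16, 16), dim=1)  # valid soft labels, classes on dim=1

# Cross-entropy H(p, p) equals the entropy H(p): > 0 unless p is one-hot
ce_self = torch.sum(-p * torch.log(p), dim=1).mean()

# KL(p || p) is exactly the quantity that vanishes for identical inputs;
# F.kl_div takes log-probabilities as input and probabilities as target
kl_self = F.kl_div(torch.log(p), p, reduction="batchmean")

# If the model already ends in softmax, log_softmax normalizes a second time,
# so it does not recover log(p)
double = F.log_softmax(p, dim=1)
print(ce_self.item(), kl_self.item(), torch.allclose(double, torch.log(p)))
```

A common alternative to taking the log of softmax outputs is to drop the final softmax from the model and pass raw logits to a loss that applies `log_softmax` itself, which is also more numerically stable.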
I can share more code if it helps illustrate the problem, but this is part of a project, so there are some parts of the code I may not be able to share.