Would you please show me how to fix this? I changed it like this:
def forward(self, logits, labels):
N, C, H, W = logits.size()
n_pixs = N * H * W
logits = logits.permute(0, 2, 3, 1).contiguous().view(-1, C)
with torch.no_grad():
scores = F.softmax(logits, dim=1).cpu().detach()
labels = labels.view(-1)
labels_cpu = labels.cpu().detach()
invalid_mask = labels_cpu==self.ignore_lb
labels_cpu[invalid_mask] = 0
picks = scores[torch.arange(n_pixs), labels_cpu]
picks[invalid_mask] = 1
sorteds, inds = torch.sort(picks)
thresh = self.thresh if sorteds[self.n_min]<self.thresh else sorteds[n_min]
labels[picks>thresh] = self.ignore_lb
loss = self.criteria(logits, labels)
return loss
But the problem still exists