Thank you for this amazing library, I’m really loving PyTorch.
I’m working on segmentation of 3D images (somewhat large: each image is 160x192x160). I implemented my own dice loss:
```python
def loss(pred, target, eps=1e-7):
    # pred: shape (batch_size, n_classes, 160, 192, 160), float32 or float16
    # target: shape (batch_size, 160, 192, 160), dtype long
    n_classes = pred.shape[1]
    target_1_hot = torch.eye(n_classes)[target].type(pred.type()).to(pred.device)
    target_1_hot = target_1_hot.permute(0, 4, 1, 2, 3)
    probas = torch.softmax(pred, dim=1)
    loss = 0.
    batch_size = len(probas)
    for b in range(batch_size):  # for loop, but for this issue batch_size is in fact 1
        intersections = torch.sum(probas[b, :] * target_1_hot[b, :], (1, 2, 3))
        cardinalities = torch.sum(probas[b, :] + target_1_hot[b, :], (1, 2, 3)) + eps
        loss += (1 - 2 * torch.sum(intersections / cardinalities))
    return loss / batch_size
```
It works fine for a low number of classes, but it becomes extremely memory-hungry as the number of classes increases: with batch size 1 and half precision, it takes 2378 MB with 10 classes, but already 7700 MB for 40 classes, according to nvidia-smi.
Do you happen to know a workaround for this? Can you explain why this operation is so memory-greedy? I see that I’m declaring quite large tensors of shape (n_classes, 160, 192, 160), but they should not take all that space on their own. Maybe the computation graph stored for the backward pass is quadratic in these dimensions?
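For scale, here is my back-of-the-envelope arithmetic (`mib` is just a throwaway helper, not part of the loss): a single dense (n_classes, 160, 192, 160) float16 tensor is already large, and the loss materializes several of them (the one-hot target, the softmax output, their elementwise product and sum), plus whatever autograd saves for backward.

```python
def mib(n_classes, bytes_per_elem=2):
    # size in MiB of one dense (n_classes, 160, 192, 160) tensor
    # at 2 bytes per element (float16)
    return n_classes * 160 * 192 * 160 * bytes_per_elem / 2**20

print(mib(10))  # 93.75 MiB per such tensor with 10 classes
print(mib(40))  # 375.0 MiB with 40 classes
```

So each extra tensor of this shape costs ~9.4 MiB per class, which would explain a steep but linear (not quadratic) growth once several copies are alive at once.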
I tried to checkpoint this operation, but it seems sketchy since the loss is not inside a module. I cannot realistically compute this loss on the CPU, and checkpointing my U-Net operations is not sufficient. Ideally I’d like to work with ~100 classes! I also tried to implement part of the backward myself (everything after the softmax op), and I save maybe 5% of the memory, which is still far from enough!
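For reference, roughly what I tried with checkpointing looks like the sketch below (`dice_after_softmax` is my own helper name; as far as I understand, `torch.utils.checkpoint.checkpoint` accepts a plain function, no `nn.Module` required, and recomputes the wrapped intermediates during backward instead of storing them):

```python
import torch
from torch.utils.checkpoint import checkpoint

def dice_after_softmax(pred, target_1_hot, eps=1e-7):
    # everything from the logits onward: softmax + dice,
    # recomputed during the backward pass instead of being stored
    probas = torch.softmax(pred, dim=1)
    intersections = torch.sum(probas * target_1_hot, (2, 3, 4))
    cardinalities = torch.sum(probas + target_1_hot, (2, 3, 4)) + eps
    return (1 - 2 * torch.sum(intersections / cardinalities, dim=1)).mean()

# tiny shapes just to exercise the code path
n_classes = 4
pred = torch.randn(1, n_classes, 8, 8, 8, requires_grad=True)
target = torch.randint(0, n_classes, (1, 8, 8, 8))
target_1_hot = torch.eye(n_classes)[target].permute(0, 4, 1, 2, 3)

loss = checkpoint(dice_after_softmax, pred, target_1_hot)
loss.backward()
```

This runs, but it only delays the allocations to the backward pass; the peak still includes the recomputed intermediates, which may be why it did not help me much.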
Thank you in advance,