Multi-class dice loss is memory greedy

Hey,

Thank you for this amazing library, I’m really loving pytorch :slight_smile:

I’m working on segmentation on 3D images (somewhat large: each image is 160x192x160). I implemented my own dice loss:

def loss(pred, target, eps=1e-7):
    # pred: shape (batch_size, n_classes, 160, 192, 160), float32 or float16
    # target: shape (batch_size, 160, 192, 160), dtype long

    # derive n_classes from pred and build the one-hot target on target's device
    n_classes = pred.size(1)
    target_1_hot = torch.eye(n_classes, device=target.device)[target].type(pred.type()).to(pred.device)
    target_1_hot = target_1_hot.permute(0, 4, 1, 2, 3)

    probas = torch.softmax(pred, dim=1)

    loss = 0.
    batch_size = len(probas)
    for b in range(batch_size):  # loop over the batch; for this issue batch_size is actually 1
        intersections = torch.sum(probas[b, :] * target_1_hot[b, :], (1, 2, 3))
        cardinalities = torch.sum(probas[b, :] + target_1_hot[b, :], (1, 2, 3)) + eps
        loss += (1 - 2 * torch.sum(intersections / cardinalities))

    return loss / batch_size

It works fine for a low number of classes, but it becomes extremely memory greedy as the number of classes increases: with batch size 1 and half precision, it takes 2378 MB with 10 classes, but already 7700 MB with 40 classes, according to nvidia-smi.

Do you happen to know a workaround for this? Can you explain why this operation is so memory greedy? I see that I'm declaring some fairly large tensors of shape (n_classes, 160, 192, 160), but they shouldn't take up all that space. Maybe the computation graph stored for the backward pass is quadratic in these dimensions?

I tried to checkpoint this operation, but it seems sketchy since it's not wrapped in a module. I cannot realistically compute this loss on the CPU, and checkpointing my UNet operations is not sufficient. Ideally I'd like to work with ~100 classes! I also tried to implement the backward myself (part of it, actually: everything after the softmax op), and I do save maybe 5% of the memory, which is still far from enough!
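
For reference, this is roughly how I tried the checkpointing (a sketch rather than my exact code; dice_loss here is just my loss factored into a plain function, and I'm not sure this is the intended way to use checkpoint outside of an nn.Module):

import torch
from torch.utils.checkpoint import checkpoint

def dice_loss(pred, target_1_hot, eps=1e-7):
    # same computation as above, just written as a plain function
    probas = torch.softmax(pred, dim=1)
    intersections = torch.sum(probas * target_1_hot, (2, 3, 4))
    cardinalities = torch.sum(probas + target_1_hot, (2, 3, 4)) + eps
    return torch.mean(1 - 2 * torch.sum(intersections / cardinalities, dim=1))

# checkpoint accepts any callable, not only an nn.Module; the intermediates
# are recomputed during the backward pass instead of being stored
loss = checkpoint(dice_loss, pred, target_1_hot)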

Thank you in advance,
Best,
Maxime

I think the general setup of the loss calculation just uses large tensors and thus a lot of memory.
You can add print statements to check the memory usage, e.g. via:

print(torch.cuda.memory_allocated()/1024**2)

import torch

batch_size = 1
n_classes = 10

pred = torch.randn(batch_size, n_classes, 160, 192, 160, device='cuda')
target = torch.randint(0, n_classes, (batch_size, 160, 192, 160), device='cuda')
print(torch.cuda.memory_allocated()/1024**2)

target_1_hot = torch.eye(n_classes, device=target.device)[target].type(pred.type())
target_1_hot = target_1_hot.permute(0, 4, 1, 2, 3)
print(torch.cuda.memory_allocated()/1024**2)

probas = torch.softmax(pred, dim=1)
print(torch.cuda.memory_allocated()/1024**2)

loss = 0.
eps = 1e-6
batch_size = len(probas)
for b in range(batch_size):  # loop over the batch; for this issue batch_size is actually 1
    intersections = torch.sum(probas[b, :] * target_1_hot[b, :], (1, 2, 3))
    cardinalities = torch.sum(probas[b, :] + target_1_hot[b, :], (1, 2, 3)) + eps
    loss += (1 - 2 * torch.sum(intersections / cardinalities))
print(torch.cuda.memory_allocated()/1024**2)

Using this code snippet you’ll get:

0.0
226.0
414.0
602.0
602.00146484375

A simple way to save some memory would be to reassign pred, e.g.:

pred = torch.softmax(pred, dim=1)

and use it in the following calculations, which would save approx. 180MB in this use case.
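
The rest of the snippet would then look roughly like this (reusing the name lets the original logits be freed in this standalone snippet):

pred = torch.softmax(pred, dim=1)  # reuse the name so the raw logits can be freed
print(torch.cuda.memory_allocated()/1024**2)

loss = 0.
eps = 1e-6
for b in range(batch_size):
    intersections = torch.sum(pred[b, :] * target_1_hot[b, :], (1, 2, 3))
    cardinalities = torch.sum(pred[b, :] + target_1_hot[b, :], (1, 2, 3)) + eps
    loss += (1 - 2 * torch.sum(intersections / cardinalities))
print(torch.cuda.memory_allocated()/1024**2)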

Thank you very much for your help.

I realize I wasn't very clear. The forward pass of this operation is fine (although you're right, I can save a few hundred MB), but the backward pass is the problem. The memory usage numbers I gave were observed during the backward pass. I suspect it does something roughly quadratic in the size of pred, but I don't know of a workaround, and I'll need to save much more than a few hundred MB to make it fit on my 12GB GPU. Do you know how to save memory in the backward pass in this case? I'm not against any solution that trades memory for time :slight_smile:
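
For reference, the partial custom backward I mentioned (the part after the softmax) looks roughly like this; it's a sketch of the idea rather than the exact code I ran, so the names are just mine:

import torch

class DicePostSoftmax(torch.autograd.Function):
    # dice term computed on the softmax output; only the one-hot target and the
    # per-class reductions are saved for the backward pass
    @staticmethod
    def forward(ctx, probas, target_1_hot, eps):
        intersections = torch.sum(probas * target_1_hot, (2, 3, 4))
        cardinalities = torch.sum(probas + target_1_hot, (2, 3, 4)) + eps
        ctx.save_for_backward(target_1_hot, intersections, cardinalities)
        return torch.mean(1 - 2 * torch.sum(intersections / cardinalities, dim=1))

    @staticmethod
    def backward(ctx, grad_output):
        target_1_hot, intersections, cardinalities = ctx.saved_tensors
        batch_size = target_1_hot.size(0)
        card = cardinalities[:, :, None, None, None]
        inter = intersections[:, :, None, None, None]
        # d(i_c / card_c)/dp = t / card_c - i_c / card_c**2, per class and per voxel
        grad_probas = (-2.0 / batch_size) * (target_1_hot / card - inter / card ** 2)
        return grad_output * grad_probas, None, None

# usage: probas = torch.softmax(pred, dim=1)
#        loss = DicePostSoftmax.apply(probas, target_1_hot, 1e-7)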