Hey,

Thank you for this amazing library; I’m really loving PyTorch.

I’m working on segmentation of 3D images (somewhat large: each image is 160x192x160). I implemented my own Dice loss:

```
def loss(pred, target, eps=1e-7):
    # pred: (batch_size, n_classes, 160, 192, 160), float32 or float16
    # target: (batch_size, 160, 192, 160), long
    n_classes = pred.shape[1]
    # One-hot encode the target, then move the class axis next to the batch axis:
    # (batch_size, 160, 192, 160, n_classes) -> (batch_size, n_classes, 160, 192, 160)
    target_1_hot = torch.eye(n_classes, device=pred.device).type(pred.type())[target]
    target_1_hot = target_1_hot.permute(0, 4, 1, 2, 3)
    probas = torch.softmax(pred, dim=1)
    loss = 0.
    batch_size = len(probas)
    for b in range(batch_size):  # a loop, but for this issue batch_size is in fact 1
        intersections = torch.sum(probas[b] * target_1_hot[b], (1, 2, 3))
        cardinalities = torch.sum(probas[b] + target_1_hot[b], (1, 2, 3)) + eps
        loss += 1 - 2 * torch.sum(intersections / cardinalities)
    return loss / batch_size
```

It works fine for a low number of classes, but it becomes extremely memory-greedy as the number of classes increases: with batch size 1 and half precision, it takes 2378 MiB with 10 classes, but already 7700 MiB with 40 classes, according to nvidia-smi.

Do you happen to know a workaround for this? Can you explain why this operation is so greedy? I see that I’m allocating fairly large tensors of shape (n_classes, 160, 192, 160), but on their own they should not take all that space. Maybe the computation graph stored for the backward pass is quadratic in these dimensions?
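For scale, here is my back-of-the-envelope estimate of a single one of these intermediates (my own arithmetic, assuming half precision, i.e. 2 bytes per element):

```
# Size of one (n_classes, 160, 192, 160) tensor in fp16.
voxels = 160 * 192 * 160  # 4,915,200 voxels per volume
for n_classes in (10, 40, 100):
    n_bytes = n_classes * voxels * 2  # 2 bytes per fp16 element
    print(f"{n_classes} classes: {n_bytes / 2**20:.0f} MiB per tensor")
```

So one such tensor is only ~375 MiB at 40 classes; my guess is that autograd keeps several of them alive for the backward (e.g. the softmax output and both operands of the elementwise product), which would explain the gigabytes.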

I tried to checkpoint this operation, but it seems sketchy since it’s not within a module. I cannot realistically compute this loss on the CPU, and checkpointing my U-Net operations is not sufficient. Ideally I’d like to work with ~100 classes! I also tried to implement the backward pass myself (part of it, actually: everything AFTER the softmax op), and I save maybe 5% of the memory, which is still far from enough!
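For reference, my manual-backward attempt looked roughly like this (a sketch for a single post-softmax sample of shape (n_classes, D, H, W); the class name and exact layout are just for illustration):

```
import torch

class SoftDiceLoss(torch.autograd.Function):
    """Dice loss on post-softmax probabilities, with a hand-written backward."""

    @staticmethod
    def forward(ctx, probas, target_1_hot, eps):
        # probas, target_1_hot: (n_classes, D, H, W) for a single sample.
        intersections = torch.sum(probas * target_1_hot, (1, 2, 3))
        cardinalities = torch.sum(probas + target_1_hot, (1, 2, 3)) + eps
        # Save only per-class scalars and the one-hot target, not the
        # full-size elementwise product that autograd would keep.
        ctx.save_for_backward(target_1_hot, intersections, cardinalities)
        return 1 - 2 * torch.sum(intersections / cardinalities)

    @staticmethod
    def backward(ctx, grad_output):
        t, inter, card = ctx.saved_tensors
        inter = inter.view(-1, 1, 1, 1)
        card = card.view(-1, 1, 1, 1)
        # With inter_c = sum(p_c * t_c) and card_c = sum(p_c + t_c) + eps:
        # d(-2 * inter_c / card_c) / dp_c[v] = -2 * (t_c[v] * card_c - inter_c) / card_c**2
        grad_probas = -2 * grad_output * (t * card - inter) / card ** 2
        return grad_probas, None, None
```

It matches autograd’s gradients on small inputs, but as I said the savings in practice are nowhere near enough.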

Thank you in advance,

Best,

Maxime