I’m working on a segmentation task.

I want to train with both cross-entropy loss and Dice loss. My target masks have shape [4, 512, 512] and contain the class indices 0 and 1, while the model output has shape [4, 2, 512, 512]:

masks_shape --> torch.Size([4, 512, 512])

output_masks_shape --> torch.Size([4, 2, 512, 512])
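For the cross-entropy part, no reshaping should be needed: `torch.nn.functional.cross_entropy` accepts raw logits of shape [N, C, H, W] together with a target of shape [N, H, W] holding `int64` class indices. A minimal sketch, with random tensors standing in for the real model output and masks (the random data is an assumption purely for illustration):

```python
import torch
import torch.nn.functional as F

# Random stand-ins for the real tensors (assumption, illustration only)
output = torch.randn(4, 2, 512, 512)        # raw logits, one channel per class
masks = torch.randint(0, 2, (4, 512, 512))  # class indices 0/1, dtype torch.int64

# cross_entropy applies log-softmax over dim 1 internally,
# so it takes the raw logits and the index mask directly
ce_loss = F.cross_entropy(output, masks)
```

If the masks were loaded as float or uint8, converting them with `masks.long()` first should avoid a dtype error.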

I found the following dice_loss function on another site:

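```
import torch

def dice_loss(input, target, class_weights, smooth=1.):
    # input:  [N, C, H, W] probabilities, e.g. torch.softmax(logits, dim=1)
    # target: [N, C, H, W] one-hot mask with the same shape as input
    # class_weights: sequence or tensor holding one weight per class
    n_classes = input.shape[1]
    loss = 0.
    for c in range(n_classes):
        # reshape (not view): input[:, c] is non-contiguous when N > 1
        iflat = input[:, c].reshape(-1)
        tflat = target[:, c].reshape(-1)
        intersection = (iflat * tflat).sum()
        w = class_weights[c]
        # weighted soft Dice loss for class c
        loss += w * (1 - ((2. * intersection + smooth) /
                          (iflat.sum() + tflat.sum() + smooth)))
    return loss
```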
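Written out, the loop above computes the following weighted soft Dice loss, where $p_{i,c}$ is the value of `input` for pixel $i$ and class $c$, $t_{i,c}$ the matching value of `target`, $w_c$ is `class_weights[c]` and $s$ is `smooth`. Because both tensors are flattened over the whole batch, $i$ runs over every pixel of every image:

$$
\mathcal{L}_{\text{Dice}} = \sum_{c=0}^{C-1} w_c \left(1 - \frac{2\sum_i p_{i,c}\, t_{i,c} + s}{\sum_i p_{i,c} + \sum_i t_{i,c} + s}\right)
$$

This also suggests why `input` should hold probabilities in [0, 1] (for example a softmax over the class dimension) rather than raw logits, and why `target` needs to be one-hot with the same [4, 2, 512, 512] shape.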

How can I convert the target mask of shape [4, 512, 512], which holds the class indices 0 and 1, into a one-hot tensor of shape [4, 2, 512, 512]?
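One way to go from [4, 512, 512] indices to a [4, 2, 512, 512] one-hot target is `torch.nn.functional.one_hot`, which appends the class dimension at the end, followed by `permute` to move that dimension to position 1. A sketch assuming two classes, with random masks standing in for the real ones:

```python
import torch
import torch.nn.functional as F

n_classes = 2
masks = torch.randint(0, 2, (4, 512, 512))  # stand-in for the real masks, values 0/1

# one_hot expects an int64 tensor and returns shape [N, H, W, C]
one_hot = F.one_hot(masks.long(), num_classes=n_classes)
# move the class axis to dim 1 -> [N, C, H, W], then cast to float for the Dice math
target = one_hot.permute(0, 3, 1, 2).float()

print(target.shape)  # torch.Size([4, 2, 512, 512])
```

The model output would then go through `torch.softmax(output, dim=1)` so that both tensors hold values in [0, 1] with matching shapes. Slices such as `target[:, c]` are not contiguous, so flattening them with `.reshape(-1)` is safer than `.view(-1)`.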

I also want to use class_weights because my classes are imbalanced.
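Class weights can feed into both terms: `F.cross_entropy` takes a per-class `weight` tensor, and the Dice term can multiply each class's loss by the same weights. The sketch below is one possible combination, not a standard recipe; the helper name `combined_loss`, its `dice_factor` parameter, and the vectorized Dice (same formula as the loop above, summed over batch and spatial dimensions) are all assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, masks, class_weights, smooth=1.0, dice_factor=1.0):
    """Hypothetical helper: weighted cross-entropy + weighted soft Dice.

    logits:        [N, C, H, W] raw model output
    masks:         [N, H, W] int64 class indices
    class_weights: [C] float tensor
    """
    n_classes = logits.shape[1]
    # Weighted cross-entropy works directly on logits and index masks
    ce = F.cross_entropy(logits, masks, weight=class_weights)

    # Dice needs probabilities and a one-hot target of the same shape
    probs = torch.softmax(logits, dim=1)
    target = F.one_hot(masks, num_classes=n_classes).permute(0, 3, 1, 2).float()

    # Sum over batch and spatial dims, keeping one value per class
    dims = (0, 2, 3)
    intersection = (probs * target).sum(dims)
    denom = probs.sum(dims) + target.sum(dims)
    dice_per_class = 1 - (2 * intersection + smooth) / (denom + smooth)
    dice = (class_weights * dice_per_class).sum()

    return ce + dice_factor * dice

# Usage with random stand-in tensors and example weights
logits = torch.randn(4, 2, 512, 512, requires_grad=True)
masks = torch.randint(0, 2, (4, 512, 512))
weights = torch.tensor([0.3, 0.7])
loss = combined_loss(logits, masks, weights)
loss.backward()
```

Both terms are differentiable with respect to the logits, so a single `backward()` call trains against the sum; how strongly to weight the Dice term relative to cross-entropy is a tuning choice.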

How should I calculate class_weights?
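There is no single correct formula, but a common heuristic is inverse class frequency: count the pixels of each class over the training masks and set `weight_c = total / (n_classes * count_c)`, so rarer classes get larger weights (this is the same formula scikit-learn uses for `class_weight='balanced'`). A sketch, where the helper name `compute_class_weights` is hypothetical and `masks_iter` is assumed to yield integer index tensors of any shape:

```python
import torch

def compute_class_weights(masks_iter, n_classes=2):
    """Hypothetical helper: inverse-frequency class weights from index masks."""
    counts = torch.zeros(n_classes, dtype=torch.float64)
    for m in masks_iter:
        # bincount over the flattened mask counts the pixels of each class index
        counts += torch.bincount(m.reshape(-1).long(), minlength=n_classes).double()
    total = counts.sum()
    # weight_c = total / (n_classes * count_c); clamp avoids division by zero
    weights = total / (n_classes * counts.clamp(min=1))
    return weights.float()

# Example: a batch where class 1 covers 1/4 of the pixels
masks = torch.zeros(4, 512, 512, dtype=torch.long)
masks[:, :256, :256] = 1
weights = compute_class_weights([masks])
print(weights)  # tensor([0.6667, 2.0000])
```

In practice the counts would be accumulated over the whole training set (for example by iterating the masks from a DataLoader). The result can be passed as `weight=` to `F.cross_entropy` (on the same device as the model output) and indexed as `class_weights[c]` in `dice_loss`.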