I have two implementations of weighted dice here. I was using the first one but vectorised it later. These are for 3D image of size [bs,5,96,96,64] . They are both giving me different answers. I cannot figure why. Which one is correct?
(All I really did was remove the for loop for a ‘:’ across the channels)
import torch
w=np.array([0.00233549, 0.14019698, 0.04583725, 0.66542272, 0.14620756])
loss_weights=torch.from_numpy(w)
def dice_loss(inpu, target): # [bs,5,96,96,64]
smooth = .1
def calc(iflat,tflat):
intersection = (iflat * tflat).sum()
return (((2. * intersection + smooth) /
(iflat.sum() + tflat.sum() + smooth)))
a=0
b=0
c=0
d=0
e=0
for k in range(inpu.size()[0]):
ip=inpu[k,0,:,:,:].view(-1)
tar=target[k,0,:,:,:].view(-1)
a+=calc(ip,tar)
ip=inpu[k,1,:,:,:].view(-1)
tar=target[k,1,:,:,:].view(-1)
b+=calc(ip,tar)
ip=inpu[k,2,:,:,:].view(-1)
tar=target[k,2,:,:,:].view(-1)
c+=calc(ip,tar)
ip=inpu[k,3,:,:,:].view(-1)
tar=target[k,3,:,:,:].view(-1)
d+=calc(ip,tar)
ip=inpu[k,4,:,:,:].view(-1)
tar=target[k,4,:,:,:].view(-1)
e+=calc(ip,tar)
raw_scores=(loss_weights[0]*a + loss_weights[1]*b + loss_weights[2]*c + loss_weights[3]*d + loss_weights[4]*e)/inpu.size()[0]
return 1.0 - raw_scores
Second
def dice_loss2(inpu, target): # [bs,5,96,96,64]
smooth = .1
def calc(iflat,tflat):
intersection = (iflat * tflat).sum()
return (((2. * intersection + smooth) /
(iflat.sum() + tflat.sum() + smooth)))
ip=inpu[:,0,:,:,:].contiguous().view(-1)
tar=target[:,0,:,:,:].contiguous().view(-1)
a=calc(ip,tar)
ip=inpu[:,1,:,:,:].contiguous().view(-1)
tar=target[:,1,:,:,:].contiguous().view(-1)
b=calc(ip,tar)
ip=inpu[:,2,:,:,:].contiguous().view(-1)
tar=target[:,2,:,:,:].contiguous().view(-1)
c=calc(ip,tar)
ip=inpu[:,3,:,:,:].contiguous().view(-1)
tar=target[:,3,:,:,:].contiguous().view(-1)
d=calc(ip,tar)
ip=inpu[:,4,:,:,:].contiguous().view(-1)
tar=target[:,4,:,:,:].contiguous().view(-1)
e=calc(ip,tar)
raw_scores=(loss_weights[0]*a + loss_weights[1]*b + loss_weights[2]*c + loss_weights[3]*d + loss_weights[4]*e)/inpu.size()[0]
return 1.0 - raw_scores