Happy New Year!
I have a probability map $S$ generated by a segmentation network (UNet) and a ground truth $G$. The soft Dice coefficient is defined as
$$\text{Dice} = \frac{2\sum (G \cdot S)}{|G| + |S|},$$
and the soft Dice loss is $1 - \text{Dice}$.
Given $S$ and $G$, the soft Dice loss can be computed with the function below. The weights of the segmentation network are updated using the gradient of the loss with respect to $S$. As I understand it, the update uses the whole gradient with respect to $S$, which has size $(w, h)$. My question: can we restrict the weight update to certain positions, i.e. only the pixels where $G$ is 1, as in the example below?
tensor([[0., 0., 1., 1.],
[0., 1., 1., 1.],
[0., 1., 1., 0.],
[1., 0., 0., 0.]])
weight at conv_1_1 before backward tensor([[ 0.0788, 0.3314, -0.2092, 0.4829],
[-0.4578, -0.4749, 0.3515, 0.2772]])
tensor(0.4853, grad_fn=<RsubBackward1>)
tensor(0.4852, grad_fn=<RsubBackward1>)
tensor(0.4851, grad_fn=<RsubBackward1>)
tensor(0.4849, grad_fn=<RsubBackward1>)
tensor(0.4846, grad_fn=<RsubBackward1>)
tensor(0.4842, grad_fn=<RsubBackward1>)
tensor(0.4839, grad_fn=<RsubBackward1>)
tensor(0.4834, grad_fn=<RsubBackward1>)
tensor(0.4830, grad_fn=<RsubBackward1>)
tensor(0.4824, grad_fn=<RsubBackward1>)
weight at conv_1_1 after backward tensor([[ 0.0845, 0.3356, -0.2084, 0.4809],
[-0.4635, -0.4792, 0.3507, 0.2791]])
This is my code
import torch
from torch import nn
import torch.optim as optim
def make_one_hot(labels, C=2):
    """Convert an integer label map into a one-hot float tensor.

    Args:
        labels: int64 tensor shaped (N, 1, *spatial) holding class indices
            in the range [0, C). Any number of trailing spatial dims is
            accepted (2-D images, 3-D volumes, ...).
        C: number of classes, i.e. the size of the output channel dim.

    Returns:
        A float32 tensor shaped (N, C, *spatial) on the same device as
        ``labels``, holding 1.0 in the channel selected by each label and
        0.0 everywhere else.
    """
    # Allocate directly on labels' device. The legacy torch.FloatTensor(...)
    # constructor always allocated on the CPU, which breaks GPU training.
    one_hot = torch.zeros(
        (labels.size(0), C, *labels.shape[2:]),
        dtype=torch.float32,
        device=labels.device,
    )
    # For every position, write 1.0 into channel labels[n, 0, ...] (dim 1).
    return one_hot.scatter_(1, labels, 1.0)
class SoftDiceLoss(nn.Module):
    """Soft Dice loss: ``1 - 2 * sum(S * T) / (sum(S) + sum(T) + smooth)``.

    ``S`` is the predicted per-class probability map and ``T`` is the
    one-hot encoding of the integer label map. Both are flattened over every
    dimension (batch, class and spatial), so the overlap is pooled across
    all classes -- including class 0 (presumably background) -- rather than
    computed per class and averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the scalar soft Dice loss for a batch.

        Args:
            input: class probabilities shaped (N, C, ...), e.g. softmax output.
            target: int64 label map shaped (N, 1, ...) with values in [0, C).

        Returns:
            A 0-dim tensor; lower is better (it approaches 0 as the
            prediction matches the one-hot target).
        """
        smooth = 0.01  # keeps the denominator strictly positive
        # Encode with as many classes as the prediction has channels, and move
        # the encoding onto the prediction's device/dtype so the elementwise
        # product below is valid on GPU as well as CPU.
        one_hot = make_one_hot(target, C=input.size(1)).to(input)
        pred = input.contiguous().view(-1)
        truth = one_hot.contiguous().view(-1)
        inter = torch.sum(pred * truth)
        union = torch.sum(pred) + torch.sum(truth) + smooth
        # inter / union is already a scalar, so no extra reduction is needed.
        return 1 - 2.0 * inter / union
class MyNet(nn.Module):
    """Toy segmentation network producing per-pixel class probabilities.

    Pipeline: 3x3 conv (1 -> 4 channels, padding 1 keeps the spatial size)
    -> 1x1 conv (4 -> 2 channels) -> softmax over the channel axis, so at
    every pixel the two output channels sum to 1.

    NOTE: there is no non-linearity between the two convolutions, so
    together they compose to a single linear map before the softmax.
    """

    def __init__(self):
        super().__init__()
        # Attribute names stay fixed: they define the state_dict keys and the
        # training script reads model.conv_1_1 directly.
        self.conv_3_3 = nn.Conv2d(
            in_channels=1, out_channels=4, kernel_size=3, padding=1
        )
        self.conv_1_1 = nn.Conv2d(
            in_channels=4, out_channels=2, kernel_size=1, padding=0
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a single-channel image batch to 2-class probability maps."""
        features = self.conv_3_3(x)
        logits = self.conv_1_1(features)
        probabilities = self.softmax(logits)
        return probabilities
# --- Demo: fit the toy net to a random 4x4 label map with the soft Dice loss.
b, c, h, w = 1, 1, 4, 4  # PyTorch image tensors are laid out (N, C, H, W)
model = MyNet()
print('weight at conv_1_1 before backward', model.conv_1_1.weight.detach().squeeze())
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Input image: only the network weights are trained, so the input itself
# needs no gradient (requires_grad=True here only cost extra autograd work).
img = torch.rand((b, c, h, w))
# Ground-truth label map with values in {0, 1}; randint already returns int64.
G = torch.randint(0, 2, size=(b, c, h, w))
dice_loss = SoftDiceLoss()
# NOTE: convolution weights are shared by every pixel position, so there are
# no per-position weights that could be updated selectively. What can be
# restricted is which pixels contribute gradient, e.g. by multiplying the
# loss inputs by a mask, or by masking dLoss/dS before backward with
# S.register_hook(lambda grad: grad * mask).
for i in range(10):
    S = model(img)  # segmented image: per-pixel class probabilities
    optimizer.zero_grad()
    dice = dice_loss(S, G)
    dice.backward()
    optimizer.step()
    print(dice)
print('weight at conv_1_1 after backward', model.conv_1_1.weight.detach().squeeze())