# How can I update the network weights using only the gradient at certain pixel locations?

I have a probability map `S` produced by a segmentation network (UNet) and a ground-truth mask `G`. A soft Dice score is defined as

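$$\mathrm{Dice}(G, S) = \frac{2\sum_i G_i S_i}{\sum_i G_i + \sum_i S_i}, \qquad \mathcal{L}_{\mathrm{Dice}} = 1 - \mathrm{Dice}(G, S)$$

(The code below also adds a small constant to the denominator to avoid division by zero.)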

Given `S` and `G`, we can compute the soft Dice loss using the function below. The weights of the segmentation network are updated by backpropagating the gradient with respect to `S`. As I understand it, the update uses the whole gradient with respect to `S`, which has size (w, h). My question is: can we control the weight update so that only certain positions contribute, i.e. only the pixels where `G` is 1, as in the example below?

```
tensor([[0., 0., 1., 1.],
        [0., 1., 1., 1.],
        [0., 1., 1., 0.],
        [1., 0., 0., 0.]])
weight at conv_1_1 before backward tensor([[ 0.0788,  0.3314, -0.2092,  0.4829],
        [-0.4578, -0.4749,  0.3515,  0.2772]])
weight at conv_1_1 after backward tensor([[ 0.0845,  0.3356, -0.2084,  0.4809],
        [-0.4635, -0.4792,  0.3507,  0.2791]])
```

This is my code

``````import torch
from torch import nn
import torch.optim as optim

def make_one_hot(labels, C=2):
    """Convert an integer label map into a one-hot float tensor.

    Args:
        labels: int64 tensor of shape (N, 1, H, W) holding class indices in
            ``[0, C)``. ``scatter_`` requires an int64 index tensor.
        C: number of classes, i.e. the size of the channel dimension of the
            result.

    Returns:
        float32 tensor of shape (N, C, H, W), on the same device as
        ``labels``, with 1.0 in channel ``labels[n, 0, h, w]`` at every pixel
        and 0.0 elsewhere.
    """
    # Allocate on the labels' device: the legacy torch.FloatTensor constructor
    # always creates a CPU tensor, which breaks as soon as the labels (and the
    # network output they are compared against) live on a GPU.
    one_hot = torch.zeros(
        labels.size(0), C, labels.size(2), labels.size(3),
        dtype=torch.float32, device=labels.device,
    )
    # Write 1.0 along the channel dimension at the index given by each label.
    # The tensor itself is passed instead of the deprecated ``.data`` alias.
    return one_hot.scatter_(1, labels, 1.0)

class SoftDiceLoss(nn.Module):
    """Soft Dice loss: ``1 - 2*sum(S*G) / (sum(S) + sum(G) + smooth)``.

    ``S`` is the predicted per-class probability tensor and ``G`` is the
    one-hot encoding of the integer ground-truth labels. Both are flattened,
    so the sums run over every element at once (all batch items, classes and
    pixels together).

    Args:
        smooth: constant added to the denominator to avoid division by zero.
            Defaults to 0.01, the value that used to be hard-coded in
            ``forward``.
    """

    def __init__(self, smooth=0.01):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar soft Dice loss.

        Args:
            input: probability tensor, assumed (N, C, H, W) with C classes
                (e.g. a softmax over dim 1) so that its flattened order lines
                up with the one-hot target.
            target: int64 label tensor of shape (N, 1, H, W) with values in
                ``[0, C)``, the layout ``make_one_hot`` expects.
        """
        # Take the class count from the prediction instead of relying on
        # make_one_hot's default of 2, so the loss also works when C != 2.
        num_classes = input.size(1)
        # Put the one-hot target on the prediction's device: make_one_hot may
        # build it on the CPU, and the element-wise product below needs both
        # operands on the same device.
        target_1h = make_one_hot(target, num_classes).to(input.device)

        probs = input.reshape(-1)
        target_flat = target_1h.reshape(-1)
        inter = torch.sum(probs * target_flat)
        union = torch.sum(probs) + torch.sum(target_flat) + self.smooth
        # inter / union is already a scalar, so no further reduction is needed.
        return 1 - 2.0 * inter / union

class MyNet(nn.Module):
    """Toy segmentation network with two convolutions and a channel softmax.

    A 3x3 convolution (1 -> 4 channels, padding 1) is followed by a 1x1
    convolution (4 -> 2 channels). A softmax over dim 1 then turns the two
    channels into per-pixel class probabilities. Both convolutions preserve
    the spatial size of the input.
    """

    def __init__(self):
        super().__init__()
        # conv_3_3 is created before conv_1_1 so that parameter
        # initialisation draws from the RNG in the same order.
        self.conv_3_3 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.conv_1_1 = nn.Conv2d(4, 2, kernel_size=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a 1-channel image to 2-channel per-pixel probabilities."""
        features = self.conv_3_3(x)
        logits = self.conv_1_1(features)
        return self.softmax(logits)

# Sizes for the toy example: batch b, channels c (1, matching the single input
# channel of MyNet.conv_3_3), and a 4x4 spatial grid (w, h).
b, c, w, h = 1, 1, 4, 4
model = MyNet()
# Snapshot of the 1x1 conv weights, squeezed from (2, 4, 1, 1) to (2, 4), for
# comparison with the "after backward" values in the example output.
print ('weight at conv_1_1 before backward', model.conv_1_1.weight.data.squeeze())
# Plain SGD with momentum over all network parameters.
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
#print(G)
# Random input image with values in [0, 1).
# NOTE(review): requires_grad=True makes autograd also compute the gradient
# w.r.t. the input image. The weights are trained either way, so this looks
# unnecessary unless the training loop (not fully visible) reads img.grad.
img = torch.rand((b, c, w, h), requires_grad=True)
# Ground-truth image: random binary labels of shape (b, c, w, h) = (1, 1, 4, 4),
# i.e. the (N, 1, H, W) int64 layout that make_one_hot scatters along dim 1.
# torch.randint already returns int64 by default; .long() only makes it explicit.
G = torch.randint(0, 2, size=(b, c, w, h)).long()
dice_loss = SoftDiceLoss()
for i in range(10):
# Segmented image
S = model(img)