torch.where() function blocks gradient

The source code below is a much simplified version of my original model.
There are two constraints. First, I need a learnable parameter to fine-tune my weight values. Second, I have to use the torch.where() function. In this code I used a very simple condition, x > 0, but the actual condition is more complex.

Output of the code below (one training step):

```
before ---------------------
w   tensor([[[[0.3798]]]])
myParam   tensor([0.9831])
after ---------------------
w   tensor([[[[0.3724]]]])
myParam   tensor([0.9831])
```

This is the main issue: the weight w was updated by loss.backward() and optimizer.step(), but myParam was not, even after running the code more than 5 times.
Is there any way to make my parameter learn?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Parameter(torch.randn(1,1,1,1))
        self.myParam = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # scale the positive weights by myParam and write the result back into w
        self.w.data = torch.where(self.w > 0, (self.w * self.myParam), self.w).data
        # if I use this line
        #      self.w = torch.where(self.w > 0, (self.w * self.myParam), self.w)
        # TypeError: cannot assign 'torch.FloatTensor' as parameter 'w' (torch.nn.Parameter or None expected)
        return F.conv2d(x, self.w)

net = Net()

print("before ---------------------")
for name, param in net.named_parameters():
    print(name, " ", param.data)

input = torch.randn(1,1,2,2)
target = torch.ones(1,1,2,2)
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters())

# one training step
output = net(input)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("after ---------------------")
for name, param in net.named_parameters():
    print(name, " ", param.data)
```

Actually, I believe it is the misuse of .data, not torch.where(), that blocks the gradients. Using .data is discouraged (use .detach() when you really need a tensor without gradient history). .data gives you the tensor's values detached from the autograd graph, so assigning to self.w.data records nothing about how the new value was computed from myParam, and myParam never receives a gradient. It also overwrites w itself on every forward pass. Instead, keep the result in a temporary variable:

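```python
def forward(self, x):
    # a new tensor that autograd tracks back to both self.w and self.myParam
    tempVar = torch.where(self.w > 0, (self.w * self.myParam), self.w)
    return F.conv2d(x, tempVar)
```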
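A minimal sketch, assuming a recent PyTorch release, that checks this claim directly: after one backward pass, myParam.grad stays None when the result is written through .data, but is filled in when the result is kept in a temporary variable. The `use_data` flag and the `grad_of_myParam` helper are hypothetical names used only for this comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, use_data):
        super().__init__()
        self.w = nn.Parameter(torch.tensor([[[[0.5]]]]))
        self.myParam = nn.Parameter(torch.tensor([0.9]))
        self.use_data = use_data  # hypothetical switch for this comparison

    def forward(self, x):
        if self.use_data:
            # Writing through .data bypasses autograd: conv2d sees self.w as a
            # plain leaf, so nothing in the graph leads back to myParam.
            self.w.data = torch.where(self.w > 0, self.w * self.myParam, self.w).data
            return F.conv2d(x, self.w)
        # A temporary tensor keeps w * myParam in the graph.
        tempVar = torch.where(self.w > 0, self.w * self.myParam, self.w)
        return F.conv2d(x, tempVar)


def grad_of_myParam(use_data):
    torch.manual_seed(0)
    net = Net(use_data)
    x = torch.randn(1, 1, 2, 2)
    loss = ((net(x) - 1) ** 2).sum()
    loss.backward()
    return net.myParam.grad


print(grad_of_myParam(use_data=True))   # None
print(grad_of_myParam(use_data=False))  # a non-zero gradient tensor
```

Note that torch.where() itself is not the problem: the condition is a boolean mask with no gradient, and the gradient flows into whichever branch was selected for each element.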

Thank you.
Creating another variable as you suggested works well.
Can I ask one more thing?
I tried to apply the same method (creating a variable) to another line:

```python
quantize_weight = torch.round(tempVar * 4) / 4
```

but it doesn't work: neither w nor myParam is updated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Parameter(torch.randn(1,1,2,2))
        self.myParam = nn.Parameter(torch.rand(1))

    def myfun(self, w):
        tempVar = torch.where(w > 0, (w * self.myParam), w) # this line worked very well!!!
        # I added this line: quantize to multiples of 0.25
        quantize_weight = torch.round(tempVar * 4) / 4
        print(quantize_weight)
        return quantize_weight

    def forward(self, x):
        q_w = self.myfun(self.w)
        return F.conv2d(x, q_w)

net = Net()

print("before ---------------------")
for name, param in net.named_parameters():
    print(name, " ", param.data)

input = torch.randn(1,1,3,3)
target = torch.ones(1,1,2,2)
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters())

# one training step
output = net(input)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("after ---------------------")
for name, param in net.named_parameters():
    print(name, " ", param.data)
```

Output after adding the line:

```
before ---------------------
w   tensor([[[[1.1107, 1.0663],
          [1.7609, 0.6599]]]])
myParam   tensor([0.6235])
tensor([[[[0.7500, 0.7500],
          [1.0000, 0.5000]]]], grad_fn=<DivBackward0>)
after ---------------------
w   tensor([[[[1.1107, 1.0663],
          [1.7609, 0.6599]]]])
myParam   tensor([0.6235])
```

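torch.round is not usefully differentiable: its derivative is zero almost everywhere, and autograd uses a zero gradient for it, so every gradient that reaches w and myParam through it is zero.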
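A quick check, assuming a recent PyTorch release: the output of torch.round still has a grad_fn (the graph exists), but the gradient it sends back is all zeros. With all-zero gradients, Adam's update is also zero, which matches the unchanged values printed above. The input values are arbitrary sample numbers.

```python
import torch

x = torch.tensor([0.3, 0.6, 1.2], requires_grad=True)
y = torch.round(x * 4) / 4   # quantize to multiples of 0.25
y.sum().backward()
print(x.grad)                # tensor([0., 0., 0.])

# Adam with an all-zero gradient leaves the parameter exactly as it was
opt = torch.optim.Adam([x])
before = x.detach().clone()
opt.step()
print(torch.equal(x.detach(), before))  # True
```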

I need torch.round to quantize the parameters, but as you said, it is not differentiable.
How can I solve this problem?

Perhaps here? I'm not familiar with quantization, so you should probably post a separate question about it 😂.

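I solved the problem with the following round function. It rounds normally in the forward pass but passes the incoming gradient through unchanged in the backward pass (a straight-through estimator):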

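```python
import torch

class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()  # forward: ordinary rounding

    @staticmethod
    def backward(ctx, g):
        # backward: treat round as the identity and pass the gradient through
        return g

# usage inside myfun:
# quantize_weight = RoundNoGradient.apply(tempVar * 4) / 4
```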
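A minimal sketch, assuming a recent PyTorch release, of RoundNoGradient plugged into the quantizing model from the second question, compared with the `t + (t.round() - t).detach()` idiom. That idiom is a commonly used alternative way to write a straight-through estimator and is not part of this thread. The helper names `round_ste_detach` and `grads` and the `round_fn` constructor argument are hypothetical, for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RoundNoGradient(torch.autograd.Function):
    """Round in the forward pass; pass the gradient through unchanged."""

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: treat d round(x)/dx as 1.
        return g


def round_ste_detach(t):
    # Alternative straight-through trick (not from the thread): the value
    # equals round(t) up to float error, and the detached term carries no
    # gradient, so the derivative with respect to t is 1.
    return t + (t.round() - t).detach()


class Net(nn.Module):
    def __init__(self, round_fn):
        super().__init__()
        self.w = nn.Parameter(torch.randn(1, 1, 2, 2))
        self.myParam = nn.Parameter(torch.rand(1))
        self.round_fn = round_fn  # hypothetical: which rounding to use

    def forward(self, x):
        tempVar = torch.where(self.w > 0, self.w * self.myParam, self.w)
        q_w = self.round_fn(tempVar * 4) / 4  # quantize to steps of 0.25
        return F.conv2d(x, q_w)


def grads(round_fn):
    torch.manual_seed(0)  # same w, myParam and input for every call
    net = Net(round_fn)
    x = torch.randn(1, 1, 3, 3)
    target = torch.ones(1, 1, 2, 2)
    loss = nn.MSELoss(reduction="sum")(net(x), target)
    loss.backward()
    return net.w.grad, net.myParam.grad


# Custom autograd Functions are called through .apply
w_grad, p_grad = grads(RoundNoGradient.apply)
print(w_grad)
print(p_grad)
```

Both variants give the same gradients here; the custom Function makes the intent explicit, while the detach idiom avoids defining a class.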

Thank you!!!
