class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.w = nn.Parameter(torch.randn(1,1,3,3))
self.myParam = nn.Parameter(torch.rand(1))
def myfun(self,w):
tempVar = torch.where( abs(self.w) >=0.3, self.w**self.myParam, self.w)
return tempVar
def forward(self,x):
w = self.myfun(self.w)
return F.conv2d(input, w)
net = Net()
input = torch.randn(1,1,50,50)
target = torch.ones(1,1,48,48)
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters(),lr=0.0001)
for i in range(1000):
print("before ---------------------")
for name,param in net.named_parameters():
print(name," ",param.data)
output = net(input)
loss = loss_fn(output,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("after ---------------------")
for name,param in net.named_parameters():
print(name," ",param.data)
time.sleep(0.5)
I got nan values after loss backward()
before ---------------------
w tensor([[[[ 1.0108, 0.8858, -0.5732],
[-1.1128, -0.8122, -0.7874],
[ 0.1894, 0.4039, 0.6233]]]])
myParam tensor([0.5857])
after ---------------------
w tensor([[[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]]])