Hi everyone,
I’m trying to optimize a network with a custom module and architecture; I’ll attach the code below. The nodes have a specific behaviour where positive and negative weights have different effects; I modelled this with two separate non-negative weight tensors and a mask applied to the gradient.
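To illustrate, here is a minimal sketch of that pattern in isolation (simplified from the full module below; w, w_ex and mask_ex are just placeholder names):

import torch
import torch.nn as nn

w = torch.randn(4, 3)
# one non-negative tensor per sign; trainable only where the original sign matches
w_ex = nn.Parameter(torch.where(w > 0, w.abs(), 0.))
mask_ex = w > 0
w_ex.register_hook(lambda grad: grad * mask_ex)  # zeroes the gradient outside the mask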
When running the code, the gradients of the parameters contain a lot of NaNs. Note that when printing a full gradient matrix, only some columns are NaN, not the whole matrix, which makes me think the NaNs propagate from a single initial error:
[tensor(0.0267), tensor(0.0106), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(5.1160e-06), tensor(nan), tensor(4.0687e-06), tensor(nan)]
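(For reference, I locate the affected columns with a check along these lines, here on the first node’s excitatory weights:)

print(torch.isnan(model.all_nodes[0].weight_ex.grad).any(dim=0))  # one bool per output column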
I have tried using torch.autograd.set_detect_anomaly(True) to debug, and it reports an error in AddBackward0 at the line sum_f_in = torch.matmul(x**2, self.weight_in**2)+self.bias_in**2. But after manually checking, there is no overflow or anything like it. Even when setting self.weight_in and self.bias_in to zero, it still raises the same error.
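By manually checking I mean verifying the intermediate values in forward, something along these lines:

sum_f_in = torch.matmul(x**2, self.weight_in**2) + self.bias_in**2
assert torch.isfinite(sum_f_in).all()  # passes: the forward values themselves look fine
# a backward hook on the intermediate prints whether its incoming gradient already contains nans:
sum_f_in.register_hook(lambda g: print('nan in grad of sum_f_in:', torch.isnan(g).any().item()))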
So I’m lost as to what else to try to figure out where the error comes from; any help would be greatly appreciated, thanks!
Here is a script to reproduce the error:
import torch
import torch.nn as nn

torch.manual_seed(0)
# anomaly detection makes backward() raise at the offending op; comment it out to reach the gradient printout
torch.autograd.set_detect_anomaly(True)
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class Node(nn.Module):
    # A node of the network: takes a matrix with n input features and returns one with m output features.
    def __init__(self, weight, bias):
        super().__init__()
        self._c = .015436882530272481
        self._fth = 100.
        self._fmax = 4100
        self._in_a = .03051282
        self._in_b = 35.4594017
        self._out_a = 5.419425e-5
        # weight_ex holds the weights > 0 and weight_in the weights < 0, both stored as absolute values
        self.weight_ex = nn.Parameter(torch.where(weight > 0, weight.abs(), 0.))
        self._mask_ex_w = (weight > 0).to(DEVICE)  # mask for the gradient
        self.weight_in = nn.Parameter(torch.where(weight < 0, weight.abs(), 0.))
        self._mask_in_w = (weight < 0).to(DEVICE)
        self.bias_ex = nn.Parameter(torch.where(bias > 0, bias.abs(), 0.))
        self._mask_ex_b = (bias > 0).to(DEVICE)
        self.bias_in = nn.Parameter(torch.where(bias < 0, bias.abs(), 0.))
        self._mask_in_b = (bias < 0).to(DEVICE)
        self.set_hook()
    def forward(self, x):
        '''Input of shape (batch_dim, time_dim, n_features_in).'''
        # reduce the inputs
        sum_f_ex = torch.matmul(x**2, self.weight_ex**2) + self.bias_ex**2
        sum_f_in = torch.matmul(x**2, self.weight_in**2) + self.bias_in**2
        f_ex = torch.sqrt(sum_f_ex)
        f_in = torch.sqrt(sum_f_in)
        diff = f_ex - self._fth
        xc = self._c*torch.sqrt(torch.where(diff > 0, diff, 0.)) + f_in*self._out_a
        res = torch.where(f_ex > (f_in*self._in_a + self._in_b), xc, 0.)
        return res
    def set_hook(self):
        # register hooks that mask the gradients, so each parameter only updates its own entries
        def _hook_ex_w(grad):
            return grad*self._mask_ex_w
        def _hook_in_w(grad):
            return grad*self._mask_in_w
        def _hook_ex_b(grad):
            return grad*self._mask_ex_b
        def _hook_in_b(grad):
            return grad*self._mask_in_b
        self.weight_ex.register_hook(_hook_ex_w)
        self.weight_in.register_hook(_hook_in_w)
        self.bias_ex.register_hook(_hook_ex_b)
        self.bias_in.register_hook(_hook_in_b)

class Network(nn.Module):
    # create the network as a sequence of nodes
    def __init__(self, shapes):
        super().__init__()
        self.all_nodes = nn.Sequential()
        for i in range(len(shapes)):
            # build the nodes one after the other
            self.all_nodes.append(Node(
                torch.FloatTensor(*shapes[i]).uniform_(-1500, 1500),
                torch.FloatTensor(shapes[i][1]).uniform_(-500, 500)))
        # parameters for the sigmoid
        self._decision_line = nn.Parameter(torch.tensor(.7).float())
        self._slope = nn.Parameter(torch.tensor(1.).float())
    def forward(self, x):
        _, time_dim, _ = x.shape
        res = self.all_nodes(x)  # propagate through the nodes
        out = torch.sum(res, dim=[1, 2])/time_dim  # average over time
        return torch.sigmoid(self._slope*(out - self._decision_line))  # activation function

if __name__ == '__main__':
    # create the model and some random data
    shapes = [[12, 7], [7, 2], [2, 1]]
    data = torch.FloatTensor(32, 210, 12).uniform_(0, 1).to(DEVICE)
    labels = (torch.rand(32) < .5).float().to(DEVICE)
    model = Network(shapes).to(DEVICE)
    loss_fn = nn.BCELoss()
    # run the model and backpropagate the loss
    preds = model(data).ravel()
    loss = loss_fn(preds, labels)
    loss.backward()
    print([torch.norm(p.grad.cpu()) for p in model.parameters()])