Hi, I am currently trying to get the gradient of a trained neural network with respect to its input (hoping to find the optimal input for the trained model). I have read some posts, such as [How to get "image gradient" in PyTorch?], which point out that setting `requires_grad = True` on the input makes PyTorch store its gradient, which can then be read from the `.grad` field. However, my problem is a little different from that setting.
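For context, here is a minimal sketch of the approach described in that post: mark the input as requiring gradients, run a forward/backward pass, then read `input.grad`. The tiny `Sequential` model and shapes here are placeholders, not my actual network.

```python
import torch

# Toy stand-in model: any differentiable module works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)
# Leaf tensor with requires_grad=True: autograd will populate x.grad.
x = torch.randn(1, 4, requires_grad=True)
out = model(x).sum()  # reduce to a scalar so backward() needs no grad argument
out.backward()
print(x.grad)  # gradient of the scalar output w.r.t. the input, shape (1, 4)
```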
First, I train the neural network on the given input dataset. Once the parameters of the network are determined, I load the trained model and calculate the gradient with respect to the input (32 features). Part of the code is shown below:
```python
model = Net()
model.load_state_dict(torch.load('model.pkl'))
bat_size = 3
loss_func = torch.nn.MSELoss()  # mean squared error
n_items = len(training_x)
batches_per_epoch = n_items // bat_size
gradient_input = np.zeros(32)
for b in range(batches_per_epoch):
    curr_bat = np.random.choice(n_items, bat_size, replace=False)
    X = torch.Tensor(training_x[curr_bat])
    X.requires_grad_(True)
    Y = torch.Tensor(training_y[curr_bat]).view(bat_size, 1)
    print(X.grad)  # prints None
    oupt = model(X)
    loss_obj = loss_func(oupt, Y)
    loss_obj.backward()
    print(X.grad)  # it gives me all zeros?
    # accumulate the batch-averaged input gradient over all training samples
    gradient_input = np.add(np.mean(X.grad.numpy(), axis=0), gradient_input)
```
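As a self-contained sanity check of the loop above (all names here are toy stand-ins, not my real model), one way `X.grad` can legitimately come out exactly zero after loading a model is if the loaded weights are themselves zero, so it seems worth checking whether the gradient is exactly 0.0 or merely small:

```python
import torch

def input_grad_max(model, x):
    """Return the largest absolute input-gradient entry for a forward pass."""
    x = x.clone().detach().requires_grad_(True)
    model(x).sum().backward()
    return x.grad.abs().max().item()

toy = torch.nn.Linear(4, 1)  # hypothetical stand-in for the loaded model
x = torch.randn(2, 4)

with torch.no_grad():          # all-zero weights: input gradient must be zero
    toy.weight.zero_()
    toy.bias.zero_()
print(input_grad_max(toy, x))  # 0.0

with torch.no_grad():          # nonzero weights: input gradient is nonzero
    toy.weight.fill_(0.5)
print(input_grad_max(toy, x))  # 0.5 (gradient of sum(Wx+b) w.r.t. x is W)
```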
where `Net()` is the neural network defined as:
```python
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hid1 = torch.nn.Linear(32, 20)
        self.hid2 = torch.nn.Linear(20, 10)
        self.oupt = torch.nn.Linear(10, 1)
        torch.nn.init.xavier_uniform_(self.hid1.weight)  # glorot
        torch.nn.init.zeros_(self.hid1.bias)
        torch.nn.init.xavier_uniform_(self.hid2.weight)
        torch.nn.init.zeros_(self.hid2.bias)
        torch.nn.init.xavier_uniform_(self.oupt.weight)
        torch.nn.init.zeros_(self.oupt.bias)

    def forward(self, x):
        z = torch.tanh(self.hid1(x))
        z = torch.tanh(self.hid2(z))
        z = self.oupt(z)  # no activation, aka Identity()
        return z
```
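I also tried a quick check with a condensed `Sequential` equivalent of this architecture (32 → 20 → 10 → 1 with tanh activations); under default random initialization it does produce nonzero input gradients, so the architecture itself does not force zeros:

```python
import torch

# Condensed stand-in for Net above (same layer sizes and activations).
net = torch.nn.Sequential(
    torch.nn.Linear(32, 20), torch.nn.Tanh(),
    torch.nn.Linear(20, 10), torch.nn.Tanh(),
    torch.nn.Linear(10, 1),
)
x = torch.randn(3, 32, requires_grad=True)
y = torch.zeros(3, 1)
loss = torch.nn.MSELoss()(net(x), y)
loss.backward()
print(x.grad.abs().sum().item() > 0)  # True: nonzero gradient under random init
```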
When I run the code it raises no errors, but `X.grad` always comes out as all zeros.
I would really appreciate it if someone could give me some suggestions for debugging this issue.