Hi, I’m using PyTorch 0.4.0 and I want to use register_backward_hook to inspect and modify gradients.
My network is a 3-layer LSTM with an FC layer on top.
When I register a backward hook on an FC layer in a vanilla feed-forward net, grad_input is a 3-tuple. I’m not sure what each element represents, but two of them have the same dimensions as the layer’s weights and biases.
When I try the same on my LSTM network, every gradient tensor in the tuple has the size of the layer’s output, and I don’t know how to access the gradients I need or what each element represents.
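For reference, here is a minimal standalone check on a bare nn.Linear (the sizes are arbitrary and the helper name is just for illustration) that prints the hook’s grad_input/grad_output shapes next to the layer’s weight and bias shapes, so they can be compared:

import torch
import torch.nn as nn

def print_shapes(module, grad_input, grad_output):
    # module backward hook: hook(module, grad_input, grad_output)
    print('grad_input shapes: ', [tuple(g.shape) for g in grad_input if g is not None])
    print('grad_output shapes:', [tuple(g.shape) for g in grad_output])
    print('weight:', tuple(module.weight.shape), 'bias:', tuple(module.bias.shape))

lin = nn.Linear(50, 10)
lin.register_backward_hook(print_shapes)
x = torch.randn(1, 50, requires_grad=True)
lin(x).sum().backward()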
Code to reproduce the issue:
import torch
import torch.nn as nn
import torch.nn.functional as F
def printgradshape(self, grad_input, grad_output):
    # module backward hooks receive (module, grad_input, grad_output)
    for i, g in enumerate(grad_input):
        print('grad_input[%d]: ' % i, g.shape)
    for i, g in enumerate(grad_output):
        print('grad_output[%d]: ' % i, g.shape)
class VanillaConvNet(nn.Module):
    def __init__(self):
        super(VanillaConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, input):
        x = self.pool1(F.relu(self.conv1(input)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
net = VanillaConvNet()
print(net)
input = torch.randn(1, 1, 28, 28)
target = torch.tensor([3], dtype=torch.long)
loss_fn = nn.CrossEntropyLoss()  # LogSoftmax + NLLLoss
net.fc2.register_backward_hook(printgradshape)
out = net(input)
err = loss_fn(out, target)
err.backward()
class VanillaRecNet(nn.Module):
    def __init__(self, input_size=6, hidden_size=32, num_layers=3, output_size=3):
        super(VanillaRecNet, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.hidden = (torch.autograd.Variable(torch.zeros(num_layers, 1, hidden_size)),
                       torch.autograd.Variable(torch.zeros(num_layers, 1, hidden_size)))

    def forward(self, input):
        rec_out, self.hidden = self.lstm(input, self.hidden)
        out = self.fc1(rec_out)
        return out
net = VanillaRecNet()
print(net)
input = torch.randn(200, 1, 6)
target = torch.randn(200, 1, 3)
loss_fn = nn.L1Loss() # for regression
net.fc1.register_backward_hook(printgradshape)
out = net(input)
err = loss_fn(out, target)
err.backward()
Output:
VanillaConvNet(
(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=320, out_features=50, bias=True)
(fc2): Linear(in_features=50, out_features=10, bias=True)
)
('grad_input[0]: ', (1, 10))
('grad_input[1]: ', (1, 50))
('grad_input[2]: ', (50, 10))
('grad_output[0]: ', (1, 10))
VanillaRecNet(
(lstm): LSTM(6, 32, num_layers=3)
(fc1): Linear(in_features=32, out_features=3, bias=True)
)
('grad_input[0]: ', (200, 1, 3))
('grad_input[1]: ', (200, 1, 3))
('grad_output[0]: ', (200, 1, 3))
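As a workaround I can hook the tensors directly (Tensor.register_hook passes exactly one gradient per tensor, so there is no ambiguity about what it belongs to), but I’d still like to understand what the module-level grad_input tuple contains in the LSTM case. A rough sketch of what I mean:

# Hooks on individual tensors: each hook receives exactly one gradient.
net = VanillaRecNet()
net.fc1.weight.register_hook(lambda g: print('d fc1.weight:', g.shape))
net.fc1.bias.register_hook(lambda g: print('d fc1.bias:  ', g.shape))

out = net(torch.randn(200, 1, 6))
out.register_hook(lambda g: print('d output:    ', g.shape))  # gradient w.r.t. fc1's output

loss = nn.L1Loss()(out, torch.randn(200, 1, 3))
loss.backward()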