How to get gradient of output w.r.t model parameters for all inputs of a batch

Hi all,

I just wanted to ask how I can get the gradient of my network's output (y) with respect to the model's parameters (theta) for every value of the input (x), i.e. a separate gradient per input sample rather than one summed over the batch. Is this actually possible with PyTorch? I came across the register_backward_hook method to try something similar, although I'm not 100% sure what it actually stores. I've attached an example piece of code below to explain what I'm looking for.

Let's say I have a network with 1 input node, a hidden layer of 4 nodes, and 2 output nodes, and I wish to calculate the gradient of each output node with respect to all of the network's parameters, for every input value of x. In this case, I've chosen 10 input values for x.
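
To make this concrete, here is a brute-force sketch of exactly the quantity I'm after, using a throwaway nn.Sequential stand-in for the model defined further down, and one torch.autograd.grad call per sample and per output node (slow, but it spells out what I mean by "for all values of x"):

import torch
import torch.nn as nn

# throwaway stand-in for the Model class below (1 -> 4 -> 2, softplus in between)
toy_net = nn.Sequential(nn.Linear(1, 4), nn.Softplus(), nn.Linear(4, 2, bias=False))
params = list(toy_net.parameters())
x = torch.randn(10, 1)

per_sample_grads = []                      # indexed as [sample][output_node][parameter]
for i in range(x.shape[0]):
    y_i = toy_net(x[i:i + 1])              # forward pass for one sample, shape [1, 2]
    grads_per_output = []
    for j in range(y_i.shape[1]):          # one backward pass per output node
        g = torch.autograd.grad(y_i[0, j], params, retain_graph=True)
        grads_per_output.append(g)
    per_sample_grads.append(grads_per_output)

# e.g. d(output node 0)/d(first layer's weight) for sample 3:
print(per_sample_grads[3][0][0].shape)     # torch.Size([4, 1])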

From @pdiddy's explanation in #14186, it seems that grad_input relates to the quantities used in the layer's forward pass and is a tuple containing: the bias values of the layer, the input data to the layer, and the output of the layer for each input. And grad_output is a tuple of one item, which corresponds to the gradient of the loss w.r.t. the layer's output? So, in my case, that would be dy/d(layer_output). Is this correct?

I ask because reading #12331 confused me a little. The example used there is a layer without a bias term, with layer input x and layer output z, and a loss function E (I assume k is the layer index?):

 output z  (grad_output)
     ____|____
    |__layer__|
         |
 input x  (grad_input)

grad_output = [dE/dz^(k)]
grad_input = [dE/dz^(k-1), dE/dw^(k)]

I assume the w in the figure above is the weight matrix of that layer? So dE/dw^(k) would be, in my case, dy/dw^(k)? However, that has a shape of [1, 4] and not of [10, 4], which is what dy/dw^(k) for all 10 values of x would need. Why is this the case? Is it summed over the number of inputs? Is there a way to get it for all values? While writing this, it also occurred to me: would it be possible to use the chain rule to get dy/dw^(k) for all x, via dy/dz^(k) * dz^(k)/dw^(k), where dz^(k)/dw^(k) for a linear layer is just the layer's input? I've sketched this idea just below.
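
To sanity-check that chain-rule idea, here is a small self-contained sketch for a single bias-free linear layer (all names made up for the example): the per-sample gradient of the output w.r.t. the weights is just an outer product of the upstream gradient with the layer's input, and summing it over the batch matches what autograd accumulates in .grad:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2, bias=False)        # stand-in for the k-th layer
h = torch.randn(10, 4)                     # stand-in for the layer's input (10 samples)
z = layer(h)                               # layer output, shape [10, 2]

grad_out = torch.ones_like(z)              # stand-in for dy/dz^(k)
# per-sample dy/dw^(k) = dy/dz^(k) (outer product) dz^(k)/dw^(k), and dz^(k)/dw^(k) is just h
per_sample_dw = grad_out.unsqueeze(2) * h.unsqueeze(1)    # shape [10, 2, 4]

# summing over the batch reproduces what autograd puts in layer.weight.grad
z.backward(grad_out)
print(torch.allclose(per_sample_dw.sum(dim=0), layer.weight.grad))   # True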

I hope this has made some sense, and if it hasn’t please do say!

Thank you in advance!

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):

	def __init__(self, in_features, hid_features, out_features):
		super(Model, self).__init__()
		
		self.fc1 = nn.Linear(in_features, hid_features, bias=True)
		self.fc1_bh = self.fc1.register_backward_hook(self._save_grad_output)
		
		self.af1 = F.softplus
		
		self.fc2 = nn.Linear(hid_features, out_features, bias=False)
		self.fc2_bh = self.fc2.register_backward_hook(self._save_grad_output)		
		
		for name, param in self.named_parameters():
			print("Parameter: ",name)
			param.register_hook(self._save_grad)
		
	def forward(self, x):
		layer1_out = self.fc1(x)
		hid1 = self.af1(layer1_out)
		layer2_out = self.fc2(hid1)
		return layer2_out
		
	def _save_grad(self, grad):
		# gradient w.r.t. this parameter, as it will be accumulated into .grad (summed over the batch)
		print("hook: ", grad.size(), grad)
		
	def _save_grad_output(self, module, grad_input, grad_output):
		print("\n module: ", module)
		if len(grad_input) == 2:
			print("grad_input[0]: ", grad_input[0].size(), grad_input[0])
			print("grad_input[1]: ", grad_input[1].size(), grad_input[1])
		else:
			print("grad_input[0]: ", grad_input[0].size(), grad_input[0])
			# grad_input[1] can be None (e.g. the network input does not require grad)
			if grad_input[1] is not None:
				print("grad_input[1]: ", grad_input[1].size(), grad_input[1])
			else:
				print("grad_input[1]: ", grad_input[1])
			print("grad_input[2]: ", grad_input[2].size(), grad_input[2])
		print("grad_output: ", grad_output[0].size(), grad_output[0])
	

net = Model(in_features=1,
            hid_features=4,
            out_features=2)

batch_number = 10
x = torch.randn(batch_number, 1)
y = net(x)
y.backward(torch.ones_like(y))  # backpropagate ones, i.e. d(y.sum())/d(theta), so both output nodes contribute

This produces the following output:

Parameter:  fc1.weight
Parameter:  fc1.bias
Parameter:  fc2.weight

 module:  Linear(in_features=4, out_features=2, bias=False)
grad_input[0]:  torch.Size([10, 4]) tensor([[-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070],
        [-0.2764,  0.4025, -0.6863, -0.4070]])
grad_input[1]:  torch.Size([4, 2]) tensor([[11.8335, 11.8335],
        [ 8.8824,  8.8824],
        [ 7.7202,  7.7202],
        [ 4.1700,  4.1700]])
grad_output:  torch.Size([10, 2]) tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
hook:  torch.Size([2, 4]) tensor([[11.8335,  8.8824,  7.7202,  4.1700],
        [11.8335,  8.8824,  7.7202,  4.1700]])

 module:  Linear(in_features=1, out_features=4, bias=True)
grad_input[0]:  torch.Size([4]) tensor([-1.8739,  2.3567, -3.6417, -1.3669])
grad_input[1]:  None
grad_input[2]:  torch.Size([1, 4]) tensor([[ 0.0098, -0.1314,  0.9169,  0.4426]])
grad_output:  torch.Size([10, 4]) tensor([[-0.2156,  0.2564, -0.3061, -0.1028],
        [-0.1741,  0.2251, -0.3951, -0.1531],
        [-0.1385,  0.2018, -0.4554, -0.1944],
        [-0.2148,  0.2556, -0.3083, -0.1039],
        [-0.1887,  0.2353, -0.3669, -0.1359],
        [-0.1529,  0.2111, -0.4321, -0.1777],
        [-0.1766,  0.2268, -0.3904, -0.1502],
        [-0.2246,  0.2646, -0.2821, -0.0911],
        [-0.2137,  0.2548, -0.3109, -0.1052],
        [-0.1744,  0.2253, -0.3945, -0.1528]])
hook:  torch.Size([4]) tensor([-1.8739,  2.3567, -3.6417, -1.3669])
hook:  torch.Size([4, 1]) tensor([[ 0.0098],
        [-0.1314],
        [ 0.9169],
        [ 0.4426]])
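
For what it's worth, if that chain-rule reasoning holds, I think I can also reconstruct the per-sample dy/d(fc2.weight) from the code above by saving fc2's input with a forward hook and combining it with the ones I pass to backward(). A sketch reusing net and batch_number from above (saved and _save_fc2_input are just names I've made up here):

saved = {}

def _save_fc2_input(module, inputs, output):
    saved["fc2_in"] = inputs[0].detach()   # hidden activations fed into fc2, shape [10, 4]

net.fc2.register_forward_hook(_save_fc2_input)

net.zero_grad()                            # clear the grads accumulated by the earlier backward()
x = torch.randn(batch_number, 1)
y = net(x)
y.backward(torch.ones_like(y))             # the hooks above will print again here

grad_out = torch.ones_like(y)              # the dy/d(fc2 output) passed to backward()
per_sample_dw2 = grad_out.unsqueeze(2) * saved["fc2_in"].unsqueeze(1)
print(per_sample_dw2.shape)                # torch.Size([10, 2, 4])
print(torch.allclose(per_sample_dw2.sum(dim=0), net.fc2.weight.grad))   # expect True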