Hi,
I am quite confused by your code, and not sure I understand what it is trying to do.

register_backward_hook() allows you to specify a function that will be called just after the backward pass of an nn.Module has been done. It will provide you with the gradients passed as input to the backward pass (the derivative of your loss w.r.t. the output of your Module) and the gradients returned by the backward pass (the derivative of your loss w.r.t. the input of your Module).
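To make that concrete, here is a small sketch of what such a hook receives (the module and shapes are made up; recent PyTorch versions deprecate register_backward_hook in favour of register_full_backward_hook, which has exactly the semantics described above, so the sketch uses that):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
seen = {}

def hook(module, grad_input, grad_output):
    # grad_output[0] is dLoss/d(module output);
    # grad_input[0] is dLoss/d(module input).
    seen["out"] = grad_output[0].shape
    seen["in"] = grad_input[0].shape

layer.register_full_backward_hook(hook)

x = torch.randn(4, 3, requires_grad=True)
layer(x).sum().backward()
# seen["out"] is torch.Size([4, 2]); seen["in"] is torch.Size([4, 3])
```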

network.parameters() will return all the nn.Parameters contained in the nn.Module (and its children). So [x.grad.data for x in network.parameters()] will create a list of the gradients of all the parameters in your model, namely the derivatives of the loss w.r.t. the learnable parameters of your Module.

I guess my confusion is more of a fundamental understanding of gradients. Are the gradients obtained from [x.grad.data for x in network.parameters()] the same gradients used to update the parameters in gradient descent, e.g. theta = theta - lr*gradient, where gradient is the derivative of the cost function w.r.t. the parameters in the network? It sounds like they are something different - this point seems subtle to me, so please correct any misunderstandings I have - and that the gradients obtained from register_backward_hook(hook) are exactly the gradients from theta = theta - lr*gradient, right?

Also, say for whatever reason I want to manually implement dropout: which one should I use to control gradient updates, register_backward_hook(hook) or [x.grad.data for x in network.parameters()]?
For instance, if I want to directly replace grad_input from register_backward_hook(hook), can I simply return a modified gradient from the hook?
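One way this could look (a hedged sketch, not the original poster's code; the hook name is made up, and register_full_backward_hook is used as the non-deprecated variant with the same idea - returning a tuple from the hook replaces grad_input):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)

def replace_grad_input(module, grad_input, grad_output):
    # Zero the gradient w.r.t. the module's input (a crude mask, as one
    # might do for a hand-rolled dropout on the backward path).
    return tuple(torch.zeros_like(g) for g in grad_input)

layer.register_full_backward_hook(replace_grad_input)

x = torch.randn(2, 3, requires_grad=True)
layer(x).sum().backward()
# x.grad is all zeros, because the hook replaced grad_input
```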

Consider a layer that has an input i, an output o and some weights w, and suppose you have a loss function L.
If you want to minimize your loss by gradient descent, you want to obtain dL/dw and then update w as w = w - lr * dL/dw.
This dL/dw is computed during the backward pass and is stored in the Variable's .grad field.
That means that [x.grad.data for x in network.parameters()] will give you the list of all the dL/dw associated to each w in your network. If you want to do gradient descent by hand you can do:

for w in network.parameters():
    w.data -= lr * w.grad.data

If you use a hook, you will get dL/di and dL/do. These are just intermediate computations. They are used to compute the dL/dw during the backward pass using the chain rule, as we write for our layer: dL/dw = dL/do * do/dw, where dL/do is given by the chain rule applied to the next layer, and do/dw is written by hand for each layer.
The gradients passed to these hooks do not correspond to any weight that you are trying to learn.
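To make the chain-rule point concrete, here is a sketch (the layer, shapes and toy loss are made up) checking that the dL/dw autograd stores in .grad matches the hand-derived dL/do * do/dw for a linear layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2, bias=False)  # o = i @ w.T
i = torch.randn(4, 3)
L = layer(i).sum()                   # a toy loss
L.backward()

# Chain rule by hand: dL/dw = (dL/do).T @ i.
# With L = o.sum(), dL/do is a tensor of ones with o's shape (4, 2).
dL_do = torch.ones(4, 2)
manual = dL_do.t() @ i               # shape (2, 3), same as w
print(torch.allclose(layer.weight.grad, manual))  # True
```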

Does that make it clearer?
You can also take a look at the examples to see how the optimiser should be used in practice.
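A minimal sketch of the usual training step (the network, data and learning rate are made up), with torch.optim doing the w = w - lr * w.grad update for you:

```python
import torch
import torch.nn as nn

network = nn.Linear(3, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

x, target = torch.randn(8, 3), torch.randn(8, 1)
loss = nn.functional.mse_loss(network(x), target)

optimizer.zero_grad()  # clear stale .grad values
loss.backward()        # fill w.grad = dL/dw for every parameter
optimizer.step()       # apply the SGD update to every parameter
```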

As mentioned in the above comment, register_backward_hook gives you the gradients with respect to the input and output. If you want the gradients for specific parameters (weights and biases), then you need to use the register_hook method.
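A short sketch of register_hook on a parameter (layer and shapes are made up); the hook fires with dL/d(that parameter), and here we just record its shape:

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
seen = []

# The hook receives dL/d(weight); returning None leaves it unchanged.
layer.weight.register_hook(lambda grad: seen.append(grad.shape))

layer(torch.randn(5, 3)).sum().backward()
# seen[0] matches layer.weight.shape, i.e. torch.Size([2, 3])
```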

Again, register_forward_hook and register_backward_hook are available only on modules, while register_hook is available only on parameters (and other tensors).