Track Intermediate Gradient within Module

Hi everybody,

I want to track intermediate gradients in the computational graph.

There have been related questions on this topic, for example:


Yet the solutions to both were applied to fairly simple and straightforward computation graphs.

I’d like to wrap the tracking of the intermediate gradients in an optimizer class, so that it collects the non-leaf intermediate gradients and processes them with more sophisticated preconditioning methods in the fancy_optim.step() function.
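As a rough illustration of that idea, here is a minimal sketch of such an optimizer. IntermediateGradOptimizer, its watch() method and the last_norms bookkeeping are hypothetical names introduced only for this example, and the actual preconditioning is left as a placeholder (the parameters just get a plain SGD update). It relies on Tensor.register_hook being called on each freshly created intermediate tensor, after the forward pass and before backward().

```python
import torch
import torch.nn.functional as F


class IntermediateGradOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer that also collects gradients of non-leaf tensors."""

    def __init__(self, params, lr=1e-2):
        super().__init__(params, dict(lr=lr))
        self.intermediate_grads = {}  # name -> gradient captured by a hook
        self.last_norms = {}          # name -> norm seen in the last step()

    def watch(self, tensor, name):
        # Call after the forward pass (so `tensor` exists) and before
        # backward() (so the hook is in place when autograd reaches it).
        def hook(grad):
            self.intermediate_grads[name] = grad.detach().clone()
        tensor.register_hook(hook)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Placeholder for a real preconditioner: only record the norms
        # of the collected intermediate gradients.
        self.last_norms = {k: g.norm().item() for k, g in self.intermediate_grads.items()}
        # Plain SGD update on the actual parameters.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])
        self.intermediate_grads.clear()
        return loss


torch.manual_seed(0)
w1 = torch.nn.Parameter(torch.randn(21, 11))
w2 = torch.nn.Parameter(torch.randn(11, 31))
opt = IntermediateGradOptimizer([w1, w2], lr=1e-3)
w1_before = w1.detach().clone()

x = torch.randn(100, 21)
opt.zero_grad()
h = torch.mm(x, w1)    # non-leaf intermediate tensor
out = torch.mm(h, w2)
opt.watch(h, "h")      # hook registered between forward and backward
F.mse_loss(out, torch.ones_like(out)).backward()
captured_shape = opt.intermediate_grads["h"].shape  # gradient w.r.t. h
opt.step()
```

Keeping the hook registration in a separate watch() call, rather than in the constructor, is deliberate: the watched tensor is a new object on every forward pass, so it has to be re-registered each iteration.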

Below is a minimal working example in which I want to store the gradients of the class variable self.intermediate_gradient.

Tracking the parameters is relatively straightforward, and I’m aware that the two matrix multiplications could be combined into one.
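For reference, combining the two products is just associativity of matrix multiplication: (x · matrix1) · matrix2 equals x · (matrix1 · matrix2). A small check, using float64 so that rounding differences stay far below the default allclose tolerance (the shapes are taken from the example below):

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 21, dtype=torch.float64)
matrix1 = torch.randn(21, 11, dtype=torch.float64)
matrix2 = torch.randn(11, 31, dtype=torch.float64)

# Two matrix multiplications, as in ExampleLayer.forward
two_step = torch.mm(torch.mm(x, matrix1), matrix2)
# One multiplication with the precomputed (21, 31) product
combined = torch.mm(x, torch.mm(matrix1, matrix2))

same = torch.allclose(two_step, combined)
print(same)  # True
```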

Is there a way to save the intermediate gradients in an optimizer class through hooks?
Or is there some other way to track intermediate gradients through the optimizer class?

PS: Is my understanding correct that the self.state attribute of the optimizer class stores all relevant values for optimizing the parameters?
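To illustrate the PS: for the built-in optimizers, optimizer.state maps each parameter to a dict of per-parameter buffers, which are created lazily on the first step(), while hyperparameters such as the learning rate live in optimizer.param_groups. A small check with Adam (the buffer names are Adam-specific; other optimizers keep different keys):

```python
import torch

torch.manual_seed(0)
param = torch.nn.Parameter(torch.randn(3, 3))
optim = torch.optim.Adam([param], lr=1e-3)

# No per-parameter buffers exist before the first step().
n_before = len(optim.state)
print(n_before)  # 0

(param ** 2).sum().backward()
optim.step()

# optim.state is keyed by the parameter tensor itself.
keys = sorted(optim.state[param].keys())
print(keys)  # ['exp_avg', 'exp_avg_sq', 'step']

# Hyperparameters are stored per parameter group, not in state.
print(optim.param_groups[0]["lr"])  # 0.001
```

For checkpointing, optim.state_dict() bundles both the per-parameter state and the param_groups.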

```python
import torch
import torch.nn.functional as F

class ExampleLayer(torch.nn.Module):

    def __init__(self, _dim_input, _dim_intermediate, _dim_output):
        super().__init__()

        self.matrix1 = torch.nn.Parameter(torch.randn(_dim_input, _dim_intermediate))

        # I want to store and track these gradients
        self.intermediate_gradient = torch.randn(_dim_input, _dim_output).requires_grad_()

        # But only this and matrix1 are parameters
        self.matrix2 = torch.nn.Parameter(torch.randn(_dim_intermediate, _dim_output))

    def forward(self, _input: torch.Tensor):
        out = torch.mm(_input, self.matrix1)

        self.intermediate_gradient = torch.mm(out, self.matrix2)
        self.intermediate_gradient.retain_grad()

        return self.intermediate_gradient

dim_input = 21
dim_intermediate = 11
dim_output = 31

x = torch.randn(100, dim_input)  # Batch size of 100

layer = ExampleLayer(dim_input, dim_intermediate, dim_output)
optim = torch.optim.Adam(layer.parameters())

saved_grads = []

def save_grad():

    def savegrad_hook(grad):
        print('inside savegrad_hook')
        saved_grads.append(grad)
    print('inside function save_grad')
    return savegrad_hook

layer.intermediate_gradient.register_hook(save_grad())

for epoch in range(3):
    out = layer(x)
    F.mse_loss(out, torch.ones_like(out)).backward()

    print(layer.intermediate_gradient.grad)

print(saved_grads)

# print(layer.intermediate_gradient.grad)
```

Thanks!

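Okay, I was able to solve the problem by realizing that, for non-leaf tensors, register_hook has to be called precisely between the forward and the backward pass: after the forward pass, because each forward pass creates a new intermediate tensor (a hook registered earlier is attached to the old tensor object and never fires), and before the backward pass, so that the hook is in place when autograd computes that gradient.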
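A sketch of the corrected loop, assuming a trimmed-down version of ExampleLayer from the question (the placeholder tensor created in __init__ is dropped, and zero_grad()/step() calls are added for completeness). The only essential change is that register_hook is called on the tensor produced by the current forward pass:

```python
import torch
import torch.nn.functional as F


class ExampleLayer(torch.nn.Module):
    def __init__(self, dim_input, dim_intermediate, dim_output):
        super().__init__()
        self.matrix1 = torch.nn.Parameter(torch.randn(dim_input, dim_intermediate))
        self.matrix2 = torch.nn.Parameter(torch.randn(dim_intermediate, dim_output))
        self.intermediate_gradient = None  # replaced on every forward pass

    def forward(self, x):
        out = torch.mm(x, self.matrix1)
        self.intermediate_gradient = torch.mm(out, self.matrix2)
        return self.intermediate_gradient


torch.manual_seed(0)
layer = ExampleLayer(21, 11, 31)
optim = torch.optim.Adam(layer.parameters())
x = torch.randn(100, 21)
saved_grads = []

for epoch in range(3):
    optim.zero_grad()
    out = layer(x)  # creates a new non-leaf tensor each iteration
    # Register the hook on that new tensor: after forward, before backward.
    layer.intermediate_gradient.register_hook(
        lambda grad: saved_grads.append(grad.detach().clone())
    )
    F.mse_loss(out, torch.ones_like(out)).backward()
    optim.step()

print(len(saved_grads))  # 3, one gradient per iteration
```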

By comparison, module hooks registered with register_forward_hook and register_backward_hook can be set up before the forward pass, since they are attached to the module rather than to a specific tensor.
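A minimal sketch of the module-level alternative, assuming a torch.nn.Linear as a stand-in for ExampleLayer and using register_full_backward_hook (the non-deprecated successor of register_backward_hook). Because the tensor of interest here is the module's output, grad_output[0] is exactly its gradient, and the hook can be registered once, before any forward pass. This only covers module outputs; a tensor computed in the middle of forward() still needs Tensor.register_hook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
layer = torch.nn.Linear(21, 31)  # stand-in for the module from the question
saved_grad_outputs = []

def save_grad_output(module, grad_input, grad_output):
    # grad_output is a tuple holding the gradient w.r.t. each module output.
    saved_grad_outputs.append(grad_output[0].detach().clone())

# Attached to the module, so registering it before the loop is fine.
handle = layer.register_full_backward_hook(save_grad_output)

x = torch.randn(100, 21, requires_grad=True)  # so grad_input is populated too
for epoch in range(3):
    out = layer(x)
    F.mse_loss(out, torch.ones_like(out)).backward()

handle.remove()  # stop collecting once done
print(len(saved_grad_outputs), saved_grad_outputs[0].shape)  # 3 torch.Size([100, 31])
```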

While it makes a lot of sense in hindsight, maybe it could be noted in the docs?