Hi everybody,
I want to track intermediate gradients in the computational graph.
There have been related questions on this before, but the solutions there were applied to fairly simple and straightforward computation graphs.
I’m interested in wrapping the tracking of the intermediate gradients in an optimizer class, so that I can collect the intermediate (non-leaf) gradients and process them with more complicated preconditioning methods inside fancy_optim.step() (see the sketch at the very end of this post).
Down below is a minimum working example in which I want to store the gradients of the attribute self.intermediate_gradient.
Tracking the parameters themselves is relatively straightforward, and I’m aware that the two matrix multiplications could be combined into one.
Is there some way to save the intermediate gradients in an optimizer class through hooks?
Or is there some other way to track intermediate gradients through the optimizer class?
PS: Is my understanding correct that the self.state variable in the optimizer class stores all relevant values for the optimization of the parameters?
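For instance, with Adam I would expect the per-parameter running moments to live there. A quick check of what I mean (separate from the example below; the exact keys are my assumption):

import torch

p = torch.nn.Parameter(torch.randn(3))
adam = torch.optim.Adam([p])

p.sum().backward()
adam.step()

# Per-parameter state is keyed by the parameter tensor itself;
# for Adam this should contain the step counter and the running moments.
print(adam.state[p].keys())  # expecting something like dict_keys(['step', 'exp_avg', 'exp_avg_sq'])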
import torch
import torch.nn.functional as F
class ExampleLayer(torch.nn.Module):
    def __init__(self, _dim_input, _dim_intermediate, _dim_output):
        super().__init__()
        self.matrix1 = torch.nn.Parameter(torch.randn(_dim_input, _dim_intermediate))
        # I want to store and track these gradients
        self.intermediate_gradient = torch.randn(_dim_input, _dim_output).requires_grad_()
        # But only this and matrix1 are parameters
        self.matrix2 = torch.nn.Parameter(torch.randn(_dim_intermediate, _dim_output))

    def forward(self, _input: torch.Tensor):
        out = torch.mm(_input, self.matrix1)
        self.intermediate_gradient = torch.mm(out, self.matrix2)
        self.intermediate_gradient.retain_grad()
        return self.intermediate_gradient
dim_input = 21
dim_intermediate = 11
dim_output = 31
x = torch.randn(100, dim_input) # Batch size of 100
layer = ExampleLayer(dim_input, dim_intermediate, dim_output)
optim = torch.optim.Adam(layer.parameters())
saved_grads = []
def save_grad():
    def savegrad_hook(grad):
        print('inside savegrad_hook')
        saved_grads.append(grad)
    print('inside function save_grad')
    return savegrad_hook
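# NOTE: this attaches the hook to the placeholder tensor created in __init__;
# forward() rebinds self.intermediate_gradient to a new tensor, so this hook never fires.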
layer.intermediate_gradient.register_hook(save_grad())
for epoch in range(3):
    out = layer(x)
    F.mse_loss(out, torch.ones_like(out)).backward()
    print(layer.intermediate_gradient.grad)
    print(saved_grads)
    # print(layer.intermediate_gradient.grad)
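For context, here is roughly the direction I have in mind. This is just a sketch: FancyOptim and its "preconditioning" are made-up placeholders, and the hook is registered inside forward() so that it attaches to the tensor created in the current pass.

import torch
import torch.nn.functional as F


class ExampleLayer(torch.nn.Module):
    def __init__(self, dim_input, dim_intermediate, dim_output):
        super().__init__()
        self.matrix1 = torch.nn.Parameter(torch.randn(dim_input, dim_intermediate))
        self.matrix2 = torch.nn.Parameter(torch.randn(dim_intermediate, dim_output))
        self.saved_intermediate_grads = []  # filled by the hook during backward

    def forward(self, _input):
        out = torch.mm(_input, self.matrix1)
        intermediate = torch.mm(out, self.matrix2)
        # Hook the tensor of *this* forward pass so it fires when the gradient
        # w.r.t. this non-leaf tensor is computed during backward().
        intermediate.register_hook(self.saved_intermediate_grads.append)
        return intermediate


class FancyOptim(torch.optim.Adam):
    # Made-up example: Adam whose step() also sees the collected intermediate grads.
    def __init__(self, params, layer, **kwargs):
        super().__init__(params, **kwargs)
        self.layer = layer

    def step(self, closure=None):
        # Placeholder for the actual preconditioning: just report the grads.
        for g in self.layer.saved_intermediate_grads:
            print('intermediate grad of shape', tuple(g.shape))
        self.layer.saved_intermediate_grads.clear()
        return super().step(closure)


layer = ExampleLayer(21, 11, 31)
fancy_optim = FancyOptim(layer.parameters(), layer)

x = torch.randn(100, 21)
for epoch in range(3):
    fancy_optim.zero_grad()
    out = layer(x)
    F.mse_loss(out, torch.ones_like(out)).backward()
    fancy_optim.step()

Something along these lines runs, but I am not sure whether stashing the gradients on the module is the right place to collect them, or whether they rather belong in the optimizer's self.state.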
Thanks!