Is it possible to compute custom gradients for all parameters in a ParameterDict and return them, e.g. as another dict, in a custom backward pass?
```python
import torch
import torch.nn as nn


class AFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weights):
        ctx.x = x
        ctx.weights = weights
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = ...
        # The ParameterDict `weights` does not need a gradient, only its entries do.
        # Is there a way to return or even compute the gradients for the parameters in a dict?
        return grad_x, grad_weights


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.ParameterDict({'a': nn.Parameter(torch.randn(3)),
                                         'b': nn.Parameter(torch.randn(3))})
```
A custom autograd Function does not consider Tensors inside structures (such as a dict) as participating in autograd. You'll need to write a wrapper around your custom autograd Function that first unpacks the tensors from that dict before calling into the autograd Function, and packs the tensors back into the structure afterwards. After writing a lot of boilerplate to convert the dict to a list and back, it could look like the following:
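A minimal sketch of that idea, assuming the dict keys are passed alongside the unpacked values so backward can match gradients to entries (the gradient values themselves are just placeholders here):

```python
import torch
import torch.nn as nn


class AFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, keys, *weight_values):
        # `keys` is a plain tuple of strings, `weight_values` are the unpacked Parameters
        ctx.save_for_backward(x, *weight_values)
        ctx.keys = keys
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        x, *weight_values = ctx.saved_tensors
        grad_x = 2 * grad_output                      # d(2x)/dx
        # one gradient per unpacked weight, in the same order they were passed in
        grad_weights = tuple(torch.zeros_like(w) for w in weight_values)  # placeholders
        # None for the non-Tensor `keys` argument
        return (grad_x, None) + grad_weights


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.ParameterDict({'a': nn.Parameter(torch.randn(3)),
                                         'b': nn.Parameter(torch.randn(3))})

    def forward(self, x):
        # unpack the dict into a flat list before calling the Function,
        # relying on the fixed key order to map gradients back to entries
        keys = tuple(self.weights.keys())
        values = tuple(self.weights[k] for k in keys)
        return AFunction.apply(x, keys, *values)
```

With this, autograd routes the returned tuple back to the individual Parameters:

```python
m = A()
x = torch.randn(3, requires_grad=True)
m(x).sum().backward()
print(m.weights['a'].grad)   # populated from the gradients returned by backward
```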
Thank you @soulitzer for your quick reply!
Your idea is a very clever way to manage this. Unfortunately, I dynamically add new parameters in the forward pass, and I want to take these into account in the backward pass, so I may need to compute more gradients than there are parameters passed to the forward function. So far I have solved the problem by setting the gradients manually, e.g. weights['a'].grad = ..., but I am not sure whether this behaves exactly as I expect.
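For context, a simplified sketch of that workaround (the gradient values are just placeholders, and the accumulation into .grad is how I assume it should behave):

```python
import torch


class AFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weights):
        ctx.save_for_backward(x)
        ctx.weights = weights              # keep a reference to the (possibly growing) dict
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = 2 * grad_output
        # write gradients onto the dict entries directly,
        # accumulating the way autograd would
        for name, p in ctx.weights.items():
            g = torch.zeros_like(p)        # placeholder for the real gradient
            p.grad = g if p.grad is None else p.grad + g
        # the ParameterDict itself gets no gradient
        return grad_x, None
```

Entries that were added dynamically during the forward pass are covered simply because they are in the dict by the time backward runs.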