Compute custom gradients of parameters in a ParameterDict

Hi all,

Is it possible to compute custom gradients for all parameters in a ParameterDict and return them, e.g. as another dict, in a custom backward pass?

import torch
import torch.nn as nn

class AFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weights):
        ctx.x = x
        ctx.weights = weights
        return 2*x

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = ...
        # The ParameterDict `weights` itself does not need a gradient, only its entries.
        # Is there a way to compute and return the gradients for the parameters in the dict?
        return grad_x, grad_weights

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = torch.nn.ParameterDict({'a': nn.Parameter(torch.randn(3)),
                                               'b': nn.Parameter(torch.randn(3))})

Thank you!

A custom autograd Function does not treat tensors nested inside structures (such as a dict) as participating in autograd. You'll need to write a wrapper around your custom Function that unpacks the tensors from the dict before calling into the Function, and packs them back into the structure afterwards. With the boilerplate for converting a dict to a list and back, it could look like the following:

import torch

def to_dict(ks, vs):
    return {k: v for k, v in zip(ks, vs)}

def to_list(dct):
    return zip(*[(k, v) for k, v in dct.items()])

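class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ks, *vs = args
        ctx.ks = ks
        # Save the input tensors so backward can read them from ctx.saved_tensors
        ctx.save_for_backward(*vs)
        inp = to_dict(ks, vs)
        return inp["a"] * inp["b"]

    @staticmethod
    def backward(ctx, grad_out):
        ks = ctx.ks
        vs = ctx.saved_tensors
        inp = to_dict(ks, vs)
        grad_dict = {
            "a": inp["b"] * grad_out,
            "b": inp["a"] * grad_out
        }
        _, gvs = to_list(grad_dict)
        # None is returned for the non-tensor keys argument
        return None, *gvs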

def func(in_dict):
    in_keys, in_list = to_list(in_dict)
    return Func.apply(in_keys, *in_list)

inp = {
    "a": torch.tensor(1., requires_grad=True),
    "b": torch.tensor(2., requires_grad=True)
}

out = func(inp)
out.backward()
print(inp["a"].grad, inp["b"].grad)

Does that work for your case?
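
For the nn.ParameterDict from the original question, the same unpack/pack idea might look roughly like the sketch below. The names ScaleShift and the operation x * a + b are made up for illustration and are not from the question; the only point is that the keys travel as a non-tensor argument while the parameters are passed as separate tensor inputs.

```python
import torch
import torch.nn as nn

class ScaleShift(torch.autograd.Function):  # hypothetical example op: x * a + b
    @staticmethod
    def forward(ctx, x, keys, *params):
        ctx.keys = keys
        ctx.save_for_backward(x, *params)
        p = dict(zip(keys, params))
        return x * p["a"] + p["b"]

    @staticmethod
    def backward(ctx, grad_out):
        x, *params = ctx.saved_tensors
        p = dict(zip(ctx.keys, params))
        # Hand-written gradients, collected in a dict keyed like the ParameterDict
        grads = {"a": grad_out * x, "b": grad_out}
        # One return value per forward input: x, keys (non-tensor, so None),
        # then the parameters in the same order they were passed in
        return (grad_out * p["a"], None, *(grads[k] for k in ctx.keys))

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.ParameterDict({
            "a": nn.Parameter(torch.randn(3)),
            "b": nn.Parameter(torch.randn(3)),
        })

    def forward(self, x):
        keys = tuple(self.weights.keys())
        return ScaleShift.apply(x, keys, *self.weights.values())

model = A()
x = torch.randn(3, requires_grad=True)
model(x).sum().backward()
print(model.weights["a"].grad, model.weights["b"].grad)
```

Because every parameter is a real input to apply, autograd accumulates the returned gradients into each entry's .grad as usual.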

Thank you @soulitzer for your quick reply!
Your idea is a very clever way to handle this. Unfortunately, I dynamically add new parameters in the forward pass and want to take them into account in the backward pass, so I may need to compute more gradients than the number of parameters I pass to the forward function. So far I have worked around this by setting the gradients directly with weights['a'].grad = ..., but I am not sure whether this behaves exactly as I intend.
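
To make the workaround concrete, here is a minimal sketch of what setting the gradients by hand could look like. It rests on several assumptions: SideEffectFunc, the operation x * a + b, and the parameter "b" created during forward are all hypothetical. Since the ParameterDict is passed as a plain Python object, autograd never sees its entries as inputs, so nothing is accumulated into them automatically and backward has to write .grad itself.

```python
import torch
import torch.nn as nn

class SideEffectFunc(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, x, weights):
        # `weights` is an nn.ParameterDict; autograd treats it as a non-tensor argument.
        if "b" not in weights:
            # Parameter added dynamically during the forward pass (illustrative only)
            weights["b"] = nn.Parameter(torch.ones(3))
        ctx.weights = weights
        ctx.save_for_backward(x)
        return x * weights["a"] + weights["b"]

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        w = ctx.weights
        manual = {"a": grad_out * x, "b": grad_out * torch.ones_like(w["b"])}
        for k, g in manual.items():
            p = w[k]
            # Accumulate by hand so repeated backward calls behave like normal autograd
            p.grad = g if p.grad is None else p.grad + g
        # Gradient for x, and None for the non-tensor `weights` argument
        return grad_out * w["a"], None

w = nn.ParameterDict({"a": nn.Parameter(torch.tensor([1., 2., 3.]))})
x = torch.tensor([4., 5., 6.], requires_grad=True)
out = SideEffectFunc.apply(x, w)
out.sum().backward()
print(w["a"].grad, w["b"].grad)
```

Two caveats with this sketch: backward only runs if at least one tensor input to apply requires grad (here x), and gradients written this way are outside the autograd graph, so they are not covered by double backward or by tensor hooks registered on the parameters.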