Some questions about custom autograd function with parameter list


I am currently having some troubles with torch.nn.ParameterList.
The following code is not working as I expcted.

There are 2 questions below:

  1. I assume that if some argument of custom autograd functions is a list of tensors, the system does not check requires_grad of the tensors in it, is that correct?
  2. When some argument of custom function is a list of tensors, how do I return the gradients for it?
import torch

# Normal matrix multiplication works fine for inputs
# with `requires_grad = False`
weight = torch.nn.Parameter(torch.zeros(4, 4))
x = torch.randn(10, 4)
y = x @ weight

class MultiLinearFunction(torch.autograd.Function):
    def forward(ctx, input, weights):
        ctx.save_for_backward(input, *weights)

        y = []
        for weight in weights:
            y.append(input @ weight)
        return torch.stack(y)

    def backward(ctx, grad):
        input, *weights = ctx.saved_tensors

        grads_weight = [input.t() @ grad for _ in range(len(weights))]
        grads_input = []
        for weight in weights:
            grads_input.append(grad @ weight.t())

        return torch.stack(grads_input).mean(dim=0), None, None

        # This unpack statement does not work
        # return torch.stack(grads_input).mean(dim=0), *grads_weight

weights = torch.nn.ParameterList([torch.zeros(4, 4) for _ in range(2)])
x = torch.randn(10, 4)
x.requires_grad = True  # This is needed for triggering to call backward

y = MultiLinearFunction.apply(x, weights)


  • Python 3.9.7
  • PyTorch 1.13.1