Hi,
I am currently having some trouble with `torch.nn.ParameterList`. The following code is not working as I expected. There are two questions below:
- I assume that if some argument of a custom autograd function is a list of tensors, the system does not check `requires_grad` of the tensors in it. Is that correct? (A minimal demonstration of what I mean is sketched right after this list.)
- When some argument of a custom function is a list of tensors, how do I return the gradients for it?
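To illustrate the first question, here is a minimal self-contained sketch (the `Probe` class is just an illustrative name I made up): the only `Parameter` is hidden inside a plain Python list, so the output does not require grad and `backward` is never called.

```python
import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights):
        # `weights` is a plain Python list, which autograd does not track
        ctx.save_for_backward(weights[0])
        return input @ weights[0]

    @staticmethod
    def backward(ctx, grad):
        (weight,) = ctx.saved_tensors
        # Gradient w.r.t. `input`; the non-tensor list argument gets None
        return grad @ weight.t(), None

w = torch.nn.Parameter(torch.zeros(4, 4))
x = torch.randn(10, 4)   # requires_grad is False
y = Probe.apply(x, [w])
print(y.requires_grad)   # False: the Parameter inside the list is not seen
```

And here is the full code from my experiment: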
```python
import torch

# Normal matrix multiplication works fine for an input
# with `requires_grad = False`
weight = torch.nn.Parameter(torch.zeros(4, 4))
x = torch.randn(10, 4)
y = x @ weight
y.sum().backward()
print(weight.grad.max())
```
```python
class MultiLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights):
        # `weights` is a ParameterList (a Python sequence of Parameters)
        ctx.save_for_backward(input, *weights)
        y = []
        for weight in weights:
            y.append(input @ weight)
        return torch.stack(y)

    @staticmethod
    def backward(ctx, grad):
        input, *weights = ctx.saved_tensors
        grads_weight = [input.t() @ grad for _ in range(len(weights))]
        grads_input = []
        for weight in weights:
            grads_input.append(grad @ weight.t())
        return torch.stack(grads_input).mean(dim=0), None, None
        # This unpack statement does not work:
        # return torch.stack(grads_input).mean(dim=0), *grads_weight

weights = torch.nn.ParameterList([torch.zeros(4, 4) for _ in range(2)])
x = torch.randn(10, 4)
x.requires_grad = True  # This is needed to trigger the call to backward
y = MultiLinearFunction.apply(x, weights)
y.sum().backward()
```
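For what it's worth, the only workaround I have found so far is to unpack the list so that every weight is passed as a separate Tensor argument; then autograd tracks each of them and `backward` can return one gradient per weight. A sketch of what I mean (`MultiLinearFunction2` is just a name for this variant, and I am not sure it is the intended pattern):

```python
class MultiLinearFunction2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, *weights):
        # Each weight is now a separate Tensor argument, so autograd
        # checks requires_grad on every one of them.
        ctx.save_for_backward(input, *weights)
        return torch.stack([input @ w for w in weights])

    @staticmethod
    def backward(ctx, grad):
        # grad has shape (num_weights, batch, out_features)
        input, *weights = ctx.saved_tensors
        grads_weight = [input.t() @ g for g in grad]
        # Chain rule: contributions from every output slice are summed
        grad_input = torch.stack(
            [g @ w.t() for g, w in zip(grad, weights)]
        ).sum(dim=0)
        return grad_input, *grads_weight

weights = torch.nn.ParameterList([torch.zeros(4, 4) for _ in range(2)])
x = torch.randn(10, 4)  # requires_grad can stay False now
y = MultiLinearFunction2.apply(x, *weights)
y.sum().backward()
print(weights[0].grad.max())
```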
Environment:
- Python 3.9.7
- PyTorch 1.13.1