I have a network that incorporates a learnable gradient descent step and is implemented as follows:
Implementation #1:
```python
import torch
from torch import nn

def cal_grad(x, loss):
    v = torch.autograd.Variable(x, requires_grad=True)
    loss_v = loss(v)
    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v))[0]

class Net_autograd(nn.Module):
    def __init__(self):
        super(Net_autograd, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        # calculate gradient implicitly (without giving its formula)
        grad = cal_grad(x_init, lambda z: 0.5 * (z.mm(self.A) - y).pow(2).sum(dim=[1,]))
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_autograd()
x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50, requires_grad=False)
x_gt = torch.rand(10, 100, requires_grad=False)
x_hat = net(x_init, y)
loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()
print('A', net.A.grad)
print('rho', net.rho.grad)
print('x_init', x_init.grad)
```
My setup is a bit unusual: the learnable matrix `self.A` is used inside the forward pass to compute a gradient, which is then used to obtain `x_hat`. I am confused about why, after `loss.backward()`, the gradient of the parameter `self.A` is `None`.
When I reimplement the network with the analytic gradient formula written out explicitly, `self.A` can be learned, since `net.A.grad` is no longer `None`:
Implementation #2:
```python
import torch
from torch import nn

class Net_explicit(nn.Module):
    def __init__(self):
        super(Net_explicit, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        grad = (x_init.mm(self.A) - y).mm(self.A.t())  # calculate gradient explicitly
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_explicit()
x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50, requires_grad=False)
x_gt = torch.rand(10, 100, requires_grad=False)
x_hat = net(x_init, y)
loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()
print('A', net.A.grad)
print('rho', net.rho.grad)
print('x_init', x_init.grad)
```
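For reference, the explicit formula in implementation #2 is the gradient of the per-sample loss used in implementation #1. Treating one row $z$ of `x_init` as a $1\times 100$ row vector (the batch case applies the same formula row by row):

$$
\begin{aligned}
f(z) &= \tfrac{1}{2}\,\lVert zA - y\rVert_2^2, \qquad A \in \mathbb{R}^{100\times 50},\; y \in \mathbb{R}^{1\times 50},\\
\nabla_z f(z) &= (zA - y)\,A^{\top},\\
\hat{x} &= x_{\text{init}} - \rho\,(x_{\text{init}}A - y)\,A^{\top} = x_{\text{init}}\,(I - \rho A A^{\top}) + \rho\, y A^{\top}.
\end{aligned}
$$

Since $\hat{x}$ is affine in $x_{\text{init}}$, backpropagating an upstream gradient $G = \partial L/\partial \hat{x}$ should give $G\,(I - \rho A A^{\top})$ for `x_init` (using that $AA^{\top}$ is symmetric), where $I$ is the $100\times 100$ identity.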
My questions are:

The second implementation works as I expected, but why does the first one not train `self.A` (its gradient is `None`)? I guess that the result of `torch.autograd.grad` may not be differentiable with respect to the variables used inside it, since the external learnable step size `self.rho` and the input `x_init` do obtain their own gradients.
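A minimal sketch to check this guess, assuming the standard `torch.autograd.grad` behaviour (the `create_graph` argument used here also exists in 1.7.1). By default the returned gradient has no graph behind it, so nothing computed from it can send gradients back to tensors used inside the inner loss; `create_graph=True` records the gradient computation itself. The tensors `A`, `x`, `y` and the helper `inner_loss` are hypothetical stand-ins for `self.A`, `x_init`, `y` and the lambda in implementation #1.

```python
import torch

torch.manual_seed(0)
# Hypothetical small tensors; A plays the role of the learnable self.A.
A = torch.rand(4, 3, requires_grad=True)
x = torch.rand(2, 4, requires_grad=True)
y = torch.rand(2, 3)

def inner_loss(z):
    # Per-sample least-squares loss, shape (batch,)
    return 0.5 * (z.mm(A) - y).pow(2).sum(dim=1)

# Default call: the returned gradient is a plain tensor with no graph behind it,
# so nothing computed from it can send gradients back to A.
loss_v = inner_loss(x)
g_plain = torch.autograd.grad(loss_v, x, torch.ones_like(loss_v))[0]
print(g_plain.requires_grad)  # False

# create_graph=True records the gradient computation, so g_graph depends on A (and x).
loss_v = inner_loss(x)
g_graph = torch.autograd.grad(loss_v, x, torch.ones_like(loss_v), create_graph=True)[0]
print(g_graph.requires_grad)  # True

# Backpropagating through g_graph now reaches A.
g_graph.sum().backward()
print(A.grad is not None)  # True
```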
To implement more complicated pipelines, in which `self.A` may even be a large network, what can I do to make implementation #1 work as well as #2 (i.e., to make `self.A` trainable through the `autograd` mechanism)?
By the way, I am not sure whether the gradient `grad` with respect to `x_init`, computed in the forward passes of #1 and #2, is exactly the same in both. Could you please check this for me?
Note: the first dimension of `x_init`, `y` and `x_gt` (size 10) is the batch dimension. My torch version is 1.7.1, in which `torch.autograd.grad` does not have the `is_grads_batched` parameter.
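Without `is_grads_batched`, the `torch.ones_like(loss_v)` passed as `grad_outputs` in `cal_grad` should already give per-sample gradients here: it is equivalent to differentiating `loss_v.sum()`, and since each row of `x_init` only enters its own loss term, row i of the result is the gradient of sample i alone. This relies on the samples not interacting; something like batch normalization inside the loss would break it. A minimal check with hypothetical float64 tensors of the question's shapes:

```python
import torch

torch.manual_seed(0)
# Hypothetical inputs with the question's shapes (batch of 10), in float64.
A = torch.rand(100, 50, dtype=torch.float64)
y = torch.rand(10, 50, dtype=torch.float64)
x = torch.rand(10, 100, dtype=torch.float64, requires_grad=True)

# One inner loss per sample, shape (10,), as in cal_grad.
loss_v = 0.5 * (x.mm(A) - y).pow(2).sum(dim=1)
# grad_outputs=ones is the vector-Jacobian product with ones, i.e. d(loss_v.sum())/dx.
g_batch = torch.autograd.grad(loss_v, x, torch.ones_like(loss_v))[0]

# Reference: differentiate each sample's loss separately and keep only its own row.
rows = []
for i in range(x.shape[0]):
    loss_i = 0.5 * (x[i:i + 1].mm(A) - y[i:i + 1]).pow(2).sum()
    rows.append(torch.autograd.grad(loss_i, x)[0][i])
g_rows = torch.stack(rows)

print(torch.allclose(g_batch, g_rows))  # True
```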