I have a network that incorporates a learnable gradient descent step and is implemented as follows:

Implementation #1:

``````
import torch
from torch import nn

def cal_grad(v, loss):
    # calculate gradient implicitly (without giving its formula)
    loss_v = loss(v)
    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v))[0]

class Net_implicit(nn.Module):
    def __init__(self):
        super(Net_implicit, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        # gradient of 0.5 * ||x A - y||^2 w.r.t. x, obtained via autograd
        grad = cal_grad(x_init, lambda z: 0.5 * (z.mm(self.A) - y).pow(2).sum(dim=[1,]))
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_implicit()

x_hat = net(x_init, y)

loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()
``````

It is somewhat unusual that I use a learnable matrix `self.A` to compute a gradient and obtain `x_hat` inside the forward pass. What confuses me is why, after `loss.backward()`, the gradient of the parameter `self.A` is `None`.
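For reproducibility, here is how I run the snippet above and observe the issue, with dummy data shaped as described in the note at the end (`x_init` needs `requires_grad=True`, otherwise `torch.autograd.grad` refuses to differentiate with respect to it):

``````
x_init = torch.rand(10, 100, requires_grad=True)  # batch of 10 (see note below)
y = torch.rand(10, 50)
x_gt = torch.rand(10, 100)

net = Net_implicit()
x_hat = net(x_init, y)
loss = (x_hat - x_gt).pow(2).mean()
loss.backward()

print(net.A.grad)    # None -- this is what confuses me
print(net.rho.grad)  # a tensor: rho does receive a gradient
print(x_init.grad)   # also a tensor
``````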

When I re-implement the network with the analytic gradient formula written out explicitly, the parameter `self.A` is learned successfully, since `net.A.grad` is no longer `None`:

Implementation #2:

``````
import torch
from torch import nn

class Net_explicit(nn.Module):
    def __init__(self):
        super(Net_explicit, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        # analytic gradient of 0.5 * ||x A - y||^2 w.r.t. x: (x A - y) A^T
        grad = (x_init.mm(self.A) - y).mm(self.A.t())
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_explicit()

x_hat = net(x_init, y)

loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()
``````
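For completeness, the explicit `grad` line in #2 is just the closed-form gradient of the quadratic objective from #1, applied row-wise over the batch:

$$\nabla_x \left( \tfrac{1}{2}\,\lVert xA - y \rVert_2^2 \right) = (xA - y)\,A^\top$$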

1. The second implementation works as I expect, but why does the first one not train `self.A` (its gradient stays `None`)? My guess is that `torch.autograd.grad` itself may not be differentiable with respect to its internal variables, since the external learnable step size `self.rho` and the input `x_init` do receive their own `grad`s.
2. For more complicated pipelines, in which `self.A` may even be a large network, what can I do to make implementation #1 work as well as #2 (i.e., to make `self.A` trainable through the `autograd` mechanism)?
3. By the way, I am not sure whether the `grad` of `x_init` computed in the forward passes of #1 and #2 is exactly the same. Could you please check this for me? (See the comparison sketch after the note below.)
Note: the first dimension of `x_init`, `y`, and `x_gt`, of size 10, is the batch dimension. My `torch` version is `1.7.1`, and its `torch.autograd.grad` does not have the `is_grads_batched` parameter.
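For question 3, this is the check I would run myself, a minimal sketch reusing the classes and `cal_grad` from above, with the weights shared between the two nets; I expect the two gradients to agree up to floating-point error:

``````
# share the same parameters so the two grads are comparable
net1 = Net_implicit()
net2 = Net_explicit()
net2.load_state_dict(net1.state_dict())

x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50)

# implicit gradient, as computed inside Net_implicit.forward
grad_implicit = cal_grad(x_init, lambda z: 0.5 * (z.mm(net1.A) - y).pow(2).sum(dim=[1,]))

# explicit analytic gradient, as in Net_explicit.forward
grad_explicit = (x_init.mm(net2.A) - y).mm(net2.A.t())

print(torch.allclose(grad_implicit, grad_explicit, atol=1e-6))  # expect True
``````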
It seems that I solved the problem by changing the return statement of the function `cal_grad` to:

``````
    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v), create_graph=True, retain_graph=True)[0]
``````
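For reference, my understanding of why this works: with the default `create_graph=False`, `torch.autograd.grad` returns a tensor that is detached from the autograd graph, so no path from `loss` back to `self.A` through `grad` ever exists; `create_graph=True` makes the gradient computation itself part of the graph (and `retain_graph` defaults to the value of `create_graph`, so spelling it out is just for clarity). The full fixed function, plus a quick check reusing the dummy data from above:

``````
def cal_grad(v, loss):
    loss_v = loss(v)
    # create_graph=True keeps this gradient computation in the autograd graph,
    # so loss.backward() can reach self.A through the returned grad
    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v),
                               create_graph=True, retain_graph=True)[0]

net = Net_implicit()
x_hat = net(x_init, y)
loss = (x_hat - x_gt).pow(2).mean()
loss.backward()
print(net.A.grad is None)  # now prints False: A receives a gradient
``````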