Is "torch.autograd.grad" differentiable with respect to its internal variables?

I have a network that incorporates a learnable gradient descent step and is implemented as follows:


Implementation #1:

import torch
from torch import nn

def cal_grad(x, loss):
    v = torch.autograd.Variable(x, requires_grad=True)
    loss_v = loss(v)
    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v))[0]

class Net_autograd(nn.Module):
    def __init__(self):
        super(Net_autograd, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        grad = cal_grad(x_init, lambda z: 0.5 * (z.mm(self.A) - y).pow(2).sum(dim=[1,]))  # calculate gradient implicitly (without giving its formula)
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_autograd()
x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50, requires_grad=False)
x_gt = torch.rand(10, 100, requires_grad=False)

x_hat = net(x_init, y)

loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()

print('A', net.A.grad)
print('rho', net.rho.grad)
print('x_init', x_init.grad)

It is not that common to use a learnable matrix self.A inside the forward pass to compute a gradient and obtain x_hat. What confuses me is this: why, after loss.backward(), is the gradient of parameter self.A None? :thinking:
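
Here is a minimal standalone probe of what torch.autograd.grad returns, on a toy function rather than my network (the expected prints reflect my understanding of the create_graph flag):

import torch

x = torch.rand(3, requires_grad=True)
loss = (x ** 2).sum()

# default call: the returned gradient is detached from the graph
g_plain = torch.autograd.grad(loss, x, retain_graph=True)[0]

# create_graph=True: the gradient computation itself is recorded
g_graph = torch.autograd.grad(loss, x, create_graph=True)[0]

print(g_plain.requires_grad, g_plain.grad_fn)  # False None
print(g_graph.requires_grad)                   # True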

I then re-implemented the network by writing out the analytic gradient formula explicitly, and there the parameter self.A is learned successfully, since net.A.grad is no longer None:
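
For reference, the per-sample objective in both versions is f(z) = 0.5 * ||z·A − y||², so differentiating by hand gives ∇f(z) = (z·A − y)·Aᵀ, which is exactly the formula hard-coded in the forward pass below.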


Implementation #2:

import torch
from torch import nn

class Net_explicit(nn.Module):
    def __init__(self):
        super(Net_explicit, self).__init__()
        self.A = nn.Parameter(torch.rand(100, 50))
        self.rho = nn.Parameter(torch.rand(1,))

    def forward(self, x_init, y):
        grad = (x_init.mm(self.A) - y).mm(self.A.t())  # calculate gradient explicitly
        x_hat = x_init - self.rho * grad
        return x_hat

net = Net_explicit()
x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50, requires_grad=False)
x_gt = torch.rand(10, 100, requires_grad=False)

x_hat = net(x_init, y)

loss = (x_hat - x_gt).pow(2).mean()  # L2 loss
loss.backward()

print('A', net.A.grad)
print('rho', net.rho.grad)
print('x_init', x_init.grad)

I want to ask:

  1. The second implementation works as I expect, but why does the first one fail to train self.A (its grad is None)? My guess is that torch.autograd.grad itself may not be differentiable with respect to its internal variables, since the learnable gradient descent step size self.rho and the input x_init, which sit outside the call, do receive their grads.

  2. For more complicated pipelines, in which self.A may even be a large network, what can I do to make implementation #1 work as well as #2 (i.e., to make self.A trainable through the autograd mechanism)? :cold_sweat:

  3. By the way, I am not sure whether the inner gradients computed in the forward passes of #1 and #2 are numerically identical. Could you please check this for me? (A small comparison sketch follows this list.)
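
For question 3, this is the kind of check I have in mind: compare the implicit (autograd) and explicit inner gradients on random data, outside of any network. Since both should compute (x·A − y)·Aᵀ, I expect torch.allclose to print True up to floating-point error:

import torch

torch.manual_seed(0)
A = torch.rand(100, 50)
x = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50)

# implicit: let autograd differentiate the inner loss w.r.t. x
loss_x = 0.5 * (x.mm(A) - y).pow(2).sum(dim=[1])
g_implicit = torch.autograd.grad(loss_x, x, torch.ones_like(loss_x))[0]

# explicit: the hand-derived formula from implementation #2
g_explicit = (x.mm(A) - y).mm(A.t())

print(torch.allclose(g_implicit, g_explicit))  # expected: True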

Note: The first dimension of x_init, y, and x_gt, of size 10, is the batch dimension. My torch version is 1.7.1, where torch.autograd.grad does not have the is_grads_batched parameter.

It seems that I solved this problem by modifying the return statement of the function cal_grad to:

    return torch.autograd.grad(loss_v, v, torch.ones_like(loss_v), create_graph=True, retain_graph=True)[0]
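
If I understand the docs correctly, create_graph=True makes autograd record the gradient computation itself, so the returned grad carries a grad_fn and the outer loss.backward() can propagate through it into self.A (retain_graph defaults to the value of create_graph, so passing it explicitly should be redundant here). A quick sanity check with the modified cal_grad, reusing Net_autograd from implementation #1:

net = Net_autograd()  # built on the modified cal_grad
x_init = torch.rand(10, 100, requires_grad=True)
y = torch.rand(10, 50)
x_gt = torch.rand(10, 100)

loss = (net(x_init, y) - x_gt).pow(2).mean()
loss.backward()
print(net.A.grad is None)  # now prints False: self.A receives a gradient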