[Question] How to perform batched autograd?

Hi all,

I'm new to this forum and really looking forward to your help!

My question is about how to replace a for-looped backward pass with something more efficient.

Say I have an input x of shape (b, d), where b is the batch size and d is the feature dimension, and Q is a trained neural network that maps x (together with a conditioning value z) to a scalar, which I use to provide a gradient for x, i.e. -Q(z, x).backward().

Now I want to compute the gradient for each of the b samples in x independently. How can I do this efficiently?

A naive method looks like this:

import torch
from torch import nn
import numpy as np

torch.manual_seed(11111)
np.random.seed(11111)


class Net(nn.Module):
    def __init__(self, in_feature) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feature, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
        )

    def forward(self, x):
        return self.net(x)


net = Net(515)  # input dimension 515 = 512 (z) + 3 (a)

z = torch.randn((10, 3, 512))
a = torch.randn((10, 3, 3))
a.grad = None
a.requires_grad = True

s = torch.cat([z, a], dim=-1)  # (10, 3, 515)
for i in range(s.shape[0]):
    net.zero_grad()
    # each backward() accumulates only into a.grad[i]
    net(s[i]).sum((-2, -1)).backward()
_grad2 = a.grad.flatten().clone()
print(_grad2)

But this is of course inefficient, since it runs one backward pass per batch element, and I'd like to get rid of the Python loop.
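
One vectorized alternative I have been looking at is torch.func (this is only a sketch of my own, assuming PyTorch >= 2.0 so that torch.func.grad and torch.func.vmap are available; scalar_out is a helper name I made up, and it reuses net, z and a from the snippet above):

from torch.func import grad, vmap

def scalar_out(a_i, z_i):
    # per-sample forward pass: a_i is (3, 3), z_i is (3, 512)
    s_i = torch.cat([z_i, a_i], dim=-1)  # (3, 515)
    return net(s_i).sum()                # scalar for one batch element

# grad() differentiates w.r.t. the first argument (a_i),
# vmap() maps that over the leading batch dimension of a and z
per_sample_grad = vmap(grad(scalar_out), in_dims=(0, 0))(a.detach(), z)  # (10, 3, 3)

But I am not sure whether this is the right tool here.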

A method I am trying now is a single torch.autograd.grad call on the summed output, comparing it against _grad2 from the loop above:

net = Net(515)

z = torch.randn((10, 3, 512))
a = torch.randn((10, 3, 3))
a.grad = None
a.requires_grad = True
s = torch.cat([z, a], dim=-1)
_out = net(s).sum((-2, -1))  # one scalar per batch element, shape (10,)

# Batched attempt: vector-Jacobian product with a vector of ones (does not reproduce _grad2)
_grad3 = torch.autograd.grad(
    outputs=_out,
    inputs=a,
    grad_outputs=torch.ones_like(_out),
    create_graph=True,
)

_grad3 = _grad3[0].clone()

print(_grad3)
_grad3 = _grad3.flatten()
assert torch.allclose(_grad2, _grad3)
print("end")

But the assert fails: _grad2 is not close to _grad3. Since each element of _out depends only on the corresponding slice of a, I expected the gradient of the summed output to reproduce the per-sample gradients from the loop exactly.
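
For reference, here is a small self-contained toy check (my own sketch, unrelated to the network above) of the equivalence I am expecting between a per-sample loop and a single grad of the summed output:

import torch

torch.manual_seed(0)
w = torch.randn(4)
x = torch.randn(6, 4, requires_grad=True)

def f(v):
    # simple nonlinear per-sample function, a stand-in for Q
    return torch.tanh(v * w).sum()

# per-sample loop: one autograd.grad call per batch element
loop_grads = []
for i in range(x.shape[0]):
    (g,) = torch.autograd.grad(f(x[i]), x)
    loop_grads.append(g[i])
loop_grads = torch.stack(loop_grads)

# single grad of the summed output
(batched,) = torch.autograd.grad(f(x), x)

print(torch.allclose(loop_grads, batched))  # prints True

So I would expect the same equivalence to hold for the network above, and I do not see where the mismatch comes from.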

Big thanks for any help and advice!