Hi all,
I'm new to this forum and really looking forward to your help!
My question is about how to perform a for-looped backward pass more efficiently.

Say I have an input x of shape (b, d), where b is the batch size and d is the dimension, and Q is a trained neural network that maps a d-dimensional input to a scalar real value, used to provide a gradient for x, i.e. -Q(z, x).backward(), where z is another conditioning value.

Now I want to compute the gradient of each x in the batch b independently. How could I do this efficiently?
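Just to make the notation concrete, here is a tiny sketch of the setup (the shapes and the placeholder Q network below are made up for illustration, they are not my real model):

import torch
from torch import nn

b, d, d_z = 10, 3, 512                      # made-up sizes: batch, x-dimension, z-dimension
x = torch.randn(b, d, requires_grad=True)   # the input I need gradients for
z = torch.randn(b, d_z)                     # conditioning value, no gradient needed

# placeholder for the trained network Q: maps (z, x) to one scalar per sample
Q = nn.Sequential(nn.Linear(d_z + d, 8), nn.ReLU(), nn.Linear(8, 1))

out = -Q(torch.cat([z, x], dim=-1))         # shape (b, 1)
out.sum().backward()                        # the -Q(z, x).backward() call over the whole batch
print(x.grad.shape)                         # (b, d); I want each row computed independently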
A naive method looks like this:
import torch
from torch import nn
import numpy as np

torch.manual_seed(11111)
np.random.seed(11111)


class Net(nn.Module):
    def __init__(self, in_feature) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feature, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
        )

    def forward(self, x):
        return self.net(x)


net = Net(515)
z = torch.randn((10, 3, 512))
a = torch.randn((10, 3, 3))
a.grad = None
a.requires_grad = True
s = torch.cat([z, a], dim=-1)

# one backward per batch element; per-sample gradients accumulate into a.grad
for i in range(s.shape[0]):
    net.zero_grad()
    net(s[i]).sum((-2, -1)).backward()

_grad2 = a.grad.flatten().clone()
print(_grad2)
But this is of course inefficient, and it is the loop I hope to optimize away.
The method I am currently trying is:
import torch
from torch import nn
import numpy as np

torch.manual_seed(11111)
np.random.seed(11111)


class Net(nn.Module):
    def __init__(self, in_feature) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_feature, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
        )

    def forward(self, x):
        return self.net(x)


# reference: per-sample gradients accumulated by the loop, as above
net = Net(515)
z = torch.randn((10, 3, 512))
a = torch.randn((10, 3, 3))
a.grad = None
a.requires_grad = True
s = torch.cat([z, a], dim=-1)

for i in range(s.shape[0]):
    net.zero_grad()
    net(s[i]).sum((-2, -1)).backward()

_grad2 = a.grad.flatten().clone()
print(_grad2)

# batched attempt
net = Net(515)
z = torch.randn((10, 3, 512))
a = torch.randn((10, 3, 3))
a.grad = None
a.requires_grad = True
s = torch.cat([z, a], dim=-1)

_out = net(s).sum((-2, -1))
# flawed method, doesn't give correct results
_grad3 = torch.autograd.grad(
    outputs=_out,
    inputs=a,
    grad_outputs=torch.ones_like(_out),
    create_graph=True,
)
_grad3 = _grad3[0].clone()
print(_grad3)

_grad3 = _grad3.flatten()
assert torch.allclose(_grad2, _grad3)
print("end")
But the assert fails: _grad2 is not close to _grad3.
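To double-check the identity I am relying on (the gradient of the summed output should equal the per-sample gradients accumulated by the loop, since each s[i] only depends on a[i]), here is a toy sanity check with a tiny linear layer instead of my actual Net:

import torch
from torch import nn

torch.manual_seed(0)
toy = nn.Linear(4, 1)
a_toy = torch.randn(5, 4, requires_grad=True)

# loop version: one backward per sample, gradients accumulate into a_toy.grad
for i in range(a_toy.shape[0]):
    toy(a_toy[i]).sum().backward()
loop_grad = a_toy.grad.clone()

# batched version: a single backward through the summed output
a_toy.grad = None
out = toy(a_toy).sum()
(batched_grad,) = torch.autograd.grad(outputs=out, inputs=a_toy)

print(torch.allclose(loop_grad, batched_grad))  # should print True

So the identity itself seems to hold on this toy case, yet my second snippet above still fails the assert.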
Big thanks for any help and advice!