Are there any examples of using the parameter is_grads_batched in torch.autograd.grad? Thanks!
If I understand the usage correctly, you can avoid writing a for loop, since vmap will be used internally:
x = torch.randn(2, 2, requires_grad=True)
# Scalar outputs
out = x.sum() # Size([])
batched_grad = torch.arange(3) # Size([3])
grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
# loop approach
grads = torch.stack([torch.autograd.grad(out, x, torch.tensor(a))[0] for a in range(3)])
(the example is based on PyTorch's internal test)
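For non-scalar outputs the idea is the same: the batched grad just gets an extra leading batch dimension, and each slice along that dimension has to match the shape of out. A small sketch continuing from the snippet above (the shapes are the point, the values are arbitrary):
# Non-scalar outputs
out = x.pow(2)  # Size([2, 2])
batched_grad = torch.rand(3, 2, 2)  # Size([3, 2, 2]): batch dimension of 3 in front
grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)  # Size([3, 2, 2])
# equivalent loop approach
grads = torch.stack([torch.autograd.grad(out, x, v)[0] for v in batched_grad])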
Thanks a lot for the quick reply! That is exactly the answer I was looking for!
May I follow up and ask whether this parameter can help give me a batch of Jacobians? For example, see the commented code below:
x = torch.randn(2, 2)  # my input: batch_size 2, input size 2
out = model(x)  # I want the Jacobian of the logits w.r.t. the parameters; out has size 2 (batch_size) x 10 (logits)
vs = torch.eye(10)  # the vectors for the VJP; I need to vmap 10 times (once per logit) to obtain the full Jacobian
vs = torch.stack([vs] * 2, 0).permute(1, 0, 2)  # shape [10, 2, 10], because my batch size is 2
grads = torch.autograd.grad(out, tuple(model.parameters()), (vs,), is_grads_batched=True)
If my model has p parameters, I would expect this code to give me a Jacobian of size 2 (batch_size) x 10 (logits) x p (parameters). However, the batch_size dimension is always missing from the Jacobian output, so I am not sure whether I am doing something wrong?
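For reference, a minimal sketch of one way I imagine the shapes could be made to work out (continuing from the snippet above, assuming out has shape [2, 10]): batching over both the sample and the logit dimension, since each VJP with respect to the shared parameters sums over the batch, which would explain the missing batch_size dimension. I am not sure this is the intended usage:
B, C = out.shape  # 2 (batch_size), 10 (logits)
vs = torch.eye(B * C).reshape(B * C, B, C)  # one one-hot cotangent per (sample, logit) pair
grads = torch.autograd.grad(out, tuple(model.parameters()), (vs,), is_grads_batched=True)
# each entry has shape [B * C, *param.shape]; recover the batch dimension by reshaping
grads = [g.reshape(B, C, *g.shape[1:]) for g in grads]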
Hello, I would like to ask if is_grads_batched can help resolve the computational bottleneck I'm facing. The problem is that I have num_data points in a 2D space, and for each point I obtain the output [num_data, ens_size] from ens_size neural networks. I am trying to compute the first and second derivatives. The current implementation is as follows, but it runs quite slowly.
def pde_fn(x, u):
    """
    :param x: the inputs, shape = [num_data, 1, dim], requires_grad=True
    :param u: u(x), shape = [num_data, ens_size], from ens_size neural networks
    :return: f(x), shape = [num_data, ens_size], the right-hand side of the physical model
    """
    u_x1 = []
    u_x2 = []
    u_x1x1 = []
    u_x1x2 = []
    u_x2x1 = []
    u_x2x2 = []
    for i in range(u.shape[1]):
        # First derivative of ensemble member i w.r.t. x
        u_x_i = torch.autograd.grad(u[:, i].sum(), x, create_graph=True)[0]  # shape = [num_data, 1, 2]
        u_x1_i = u_x_i[..., 0]  # shape = [num_data, 1], assuming dim == 2
        u_x2_i = u_x_i[..., 1]  # shape = [num_data, 1], assuming dim == 2
        # Second derivatives: one grad call per first-derivative component, then slice
        u_x1x_i = torch.autograd.grad(u_x1_i.sum(), x, create_graph=True)[0]  # shape = [num_data, 1, 2]
        u_x2x_i = torch.autograd.grad(u_x2_i.sum(), x, create_graph=True)[0]  # shape = [num_data, 1, 2]
        u_x1x1_i = u_x1x_i[..., 0]  # shape = [num_data, 1]
        u_x1x2_i = u_x1x_i[..., 1]  # shape = [num_data, 1]
        u_x2x1_i = u_x2x_i[..., 0]  # shape = [num_data, 1]
        u_x2x2_i = u_x2x_i[..., 1]  # shape = [num_data, 1]
        u_x1.append(u_x1_i)
        u_x2.append(u_x2_i)
        u_x1x1.append(u_x1x1_i)
        u_x1x2.append(u_x1x2_i)
        u_x2x1.append(u_x2x1_i)
        u_x2x2.append(u_x2x2_i)
    u_x1 = torch.cat(u_x1, dim=-1)      # shape = [num_data, ens_size]
    u_x2 = torch.cat(u_x2, dim=-1)      # shape = [num_data, ens_size]
    u_x1x1 = torch.cat(u_x1x1, dim=-1)  # shape = [num_data, ens_size]
    u_x1x2 = torch.cat(u_x1x2, dim=-1)  # shape = [num_data, ens_size]
    u_x2x1 = torch.cat(u_x2x1, dim=-1)  # shape = [num_data, ens_size]
    u_x2x2 = torch.cat(u_x2x2, dim=-1)  # shape = [num_data, ens_size]
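For context, here is a minimal sketch of what I imagine a batched version might look like (an untested assumption on my side: it requires every backward op in my network to have a vmap batching rule, and it reuses x and u from the function above):
# Replace the Python loop over ensemble members with batched grad calls.
# Assumes u[n, i] depends only on x[n], as in the loop version above.
N, E = u.shape  # num_data, ens_size
eye = torch.eye(E, dtype=u.dtype, device=u.device)
# One cotangent per ensemble member: vs[i] selects column i of u for every data point.
vs = eye[:, None, :].expand(E, N, E)
u_x = torch.autograd.grad(u, x, (vs,), create_graph=True, is_grads_batched=True)[0]  # [E, N, 1, 2]
u_x1, u_x2 = u_x[..., 0], u_x[..., 1]  # each [E, N, 1]
# Second derivatives: batch over the ensemble dimension with one-hot cotangents
# so the members are not mixed when differentiating w.r.t. the shared x.
vs2 = eye[:, :, None, None].expand(E, E, N, 1)
u_x1x = torch.autograd.grad(u_x1, x, (vs2,), create_graph=True, is_grads_batched=True)[0]  # [E, N, 1, 2]
u_x2x = torch.autograd.grad(u_x2, x, (vs2,), create_graph=True, is_grads_batched=True)[0]  # [E, N, 1, 2]
u_x1x1, u_x1x2 = u_x1x[..., 0], u_x1x[..., 1]
u_x2x1, u_x2x2 = u_x2x[..., 0], u_x2x[..., 1]
# Each result can be brought back to [num_data, ens_size] with .squeeze(-1).transpose(0, 1)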