Batch-wise Gradient Computation using autograd

Hello,
My loss function includes partial differential equations (PDEs). For each batch, the input (xy) has size torch.Size([batch, 2, points]) and the output (u) has size torch.Size([batch, 1, points]); that is, the model predicts the response at each point of each batch element. In general, the model takes input of size torch.Size([batch, channels, points]) and returns output of size torch.Size([batch, response, points]).
The PDE includes d2u/dx2 and d2u/dy2, so I need these derivatives at each point of each element in the batch. The batch dimension is what confuses me in this calculation. Does anyone have suggestions? I would also like to avoid loops.

You can compose torch.func.vmap and torch.func.grad to compute per-sample gradients; see the torch.func Whirlwind Tour (PyTorch 2.2 documentation).
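A minimal sketch of that composition for the shapes in the question. It assumes the model is pointwise, meaning each output point depends only on its own input point. `PointwiseMLP`, `u_single`, and `second_derivatives` are hypothetical names used only for illustration. Each point is treated as a scalar function u(x, y) and differentiated twice with torch.func.grad (not torch.autograd.grad). vmap is then applied twice, once over points and once over the batch, so no Python loop is needed.

```python
import torch
import torch.nn as nn
from torch.func import grad, vmap


class PointwiseMLP(nn.Module):
    """Hypothetical model: [batch, 2, points] -> [batch, 1, points], applied per point."""

    def __init__(self, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, xy: torch.Tensor) -> torch.Tensor:
        # Move channels last so nn.Linear acts on each point independently.
        return self.net(xy.transpose(1, 2)).transpose(1, 2)


torch.manual_seed(0)
model = PointwiseMLP()


def u_single(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x, y are 0-dim tensors for one point; wrap them as [batch=1, channels=2, points=1].
    xy = torch.stack([x, y]).reshape(1, 2, 1)
    return model(xy).reshape(())  # scalar u(x, y), as torch.func.grad requires


# Nested torch.func.grad gives the pure second derivatives of the scalar function.
u_xx = grad(grad(u_single, argnums=0), argnums=0)
u_yy = grad(grad(u_single, argnums=1), argnums=1)


def second_derivatives(xy: torch.Tensor):
    """xy: [batch, 2, points] -> (d2u/dx2, d2u/dy2), each [batch, 1, points]."""
    x, y = xy[:, 0, :], xy[:, 1, :]  # each [batch, points]
    # Inner vmap maps over points, outer vmap maps over the batch dimension.
    d2x = vmap(vmap(u_xx))(x, y)
    d2y = vmap(vmap(u_yy))(x, y)
    return d2x.unsqueeze(1), d2y.unsqueeze(1)


xy = torch.randn(4, 2, 100)
u_xx_vals, u_yy_vals = second_derivatives(xy)
print(u_xx_vals.shape, u_yy_vals.shape)  # torch.Size([4, 1, 100]) torch.Size([4, 1, 100])
```

The wrapper depends on the pointwise assumption. If the model mixes information across points, for example Conv1d with kernel_size > 1, then evaluating it on a single point changes its output. In that case this sketch would not apply as written.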
