How to calculate Laplacian for multiple batches in parallel

Hi All,

I just have a quick question regarding parallelism and calculation of the Laplacian within PyTorch. As noted in a previous question I asked, I managed to get some code to calculate the Laplacian of a feed-forward-like model, defined as an R^N -> R^1 function. The original question is here if anyone is curious.

In short, I have a feedforward-like model which takes an input Tensor x of size [S,N], where S is the batch size and N is the number of nodes in the input layer. The code that calculated the Laplacian was:

y = model(x) #where model is an R^N to R^1 function

laplacian = torch.zeros(x.shape[0]) #array to store values of laplacian

for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
    laplacian[i] = torch.diagonal(hess.view(N, N), offset=0).sum()
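For anyone who wants to run it, here is a self-contained version of the loop above; the two-layer MLP, N, and the batch size are just placeholders standing in for my actual model and data:

```python
import torch

N = 4  # number of input nodes (placeholder value)
# Placeholder R^N -> R^1 network; my real model is a feed-forward net like this
model = torch.nn.Sequential(
    torch.nn.Linear(N, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 1),
)

def f(xi):
    # hessian() wants a function of one tensor returning a single-element tensor
    return model(xi).squeeze()

x = torch.randn(3, N)  # batch of S=3 samples
laplacian = torch.zeros(x.shape[0])  # per-sample Laplacian values
for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(f, xi)  # [N, N] Hessian at xi
    laplacian[i] = torch.diagonal(hess, offset=0).sum()  # trace = Laplacian
```

Here each xi already has shape [N], so the Hessian comes back as [N, N] directly and no view is needed.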

For my particular problem, I need to compute multiple batches in parallel, pass them through my network, and take a mean over all batches. The way I’ve chosen to represent this is with an input Tensor of size [S,B,N], where S, B, N are the number of samples, number of batches, and number of inputs respectively.

x = get_inputs() #returns Tensor of shape [S,B,N]

laplacian_per_batch = torch.zeros(x.shape[0:2]) #array to store laplacian for each batch

for j in range(x.shape[1]): #cycle over each batch
  xj = x[:,j,:]     #input of batch j
  yj = model(xj)    #output of batch j
  laplacianj = torch.zeros(xj.shape[0]) #laplacian array for batch j
  for i, xi in enumerate(xj): #cycle over each element within batch j
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)  #calculate
    laplacianj[i] = torch.diagonal(hess.view(N, N), offset=0).sum()                      #laplacian
  laplacian_per_batch[:,j] = laplacianj #store laplacian of batch j in total array

I know from previous questions that batch vectorization of the Hessian computation is currently a work in progress, but I was wondering whether it’s possible to run each “batch loop” (the outer loop over j) in parallel. A single pass through loop “i” takes around 10 s, and since I usually need anywhere from 8 to 32 batches, the runtime of a single pass becomes somewhat impractical.

I assume this is possible since the individual batches aren’t connected to one another. In theory I could reshape x to [S*B,N], but that would just merge the two nested loops into a single equally sized loop, rather than processing each batch in parallel.
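To illustrate the reshape I mean, here’s a minimal sketch (again with a placeholder MLP and made-up sizes for N, S, B):

```python
import torch

N, S, B = 4, 5, 2  # placeholder sizes: inputs, samples, batches
model = torch.nn.Sequential(
    torch.nn.Linear(N, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 1),
)

def f(xi):
    return model(xi).squeeze()

x = torch.randn(S, B, N)
x_flat = x.reshape(S * B, N)  # merge the sample and batch dimensions

# Same per-sample loop as before, just over S*B samples instead of S
laplacian_flat = torch.zeros(S * B)
for i, xi in enumerate(x_flat):
    hess = torch.autograd.functional.hessian(f, xi)
    laplacian_flat[i] = torch.diagonal(hess, offset=0).sum()

# Restore the [S, B] layout so column j holds the Laplacians of batch j
laplacian_per_batch = laplacian_flat.reshape(S, B)
```

This gives the same per-batch values as the nested loops, but it is still one sequential loop of length S*B, which is exactly why I’m asking about parallelism.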

Thank you in advance! :slight_smile: