Hi All,
I just have a quick question regarding parallelism and calculation of the Laplacian within PyTorch. As noted in a previous question I asked, I managed to get some code to calculate the Laplacian of a feed-forward-like model, which is an R^N -> R^1 function. The original question is here if anyone is curious.
In short, I have a feed-forward-like model which takes an input Tensor x of size [S, N], where S is the batch size and N is the number of nodes in the input layer. The code that calculated it was:
y = model(x) #where model is an R^N to R^1 function
laplacian = torch.zeros(x.shape[0]) #array to store values of laplacian
for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
    laplacian[i] = torch.diagonal(hess.view(N, N), offset=0).sum()
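For completeness, here is a self-contained, runnable version of that loop. The small toy network is just a stand-in for my actual model (which is larger); any R^N -> R^1 network should behave the same way:

```python
import torch

S, N = 3, 4  # small sizes just for illustration

# toy stand-in for the real model: any smooth R^N -> R^1 network
model_net = torch.nn.Sequential(
    torch.nn.Linear(N, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)

def model(x):
    # x: [1, N] -> scalar output
    return model_net(x).squeeze()

x = torch.randn(S, N)                 # input batch [S, N]
laplacian = torch.zeros(S)            # one Laplacian value per sample
for i, xi in enumerate(x):
    # Hessian of the scalar output w.r.t. the single sample xi: shape [1, N, 1, N]
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True)
    # trace of the N x N Hessian = Laplacian at xi
    laplacian[i] = torch.diagonal(hess.view(N, N), offset=0).sum()
```
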
For my particular problem, I need to compute multiple batches in parallel, pass them through my network, and take a mean over all batches. The way I've chosen to represent this is with an input Tensor of size [S, B, N], where S, B, and N are the number of samples, the number of batches, and the number of inputs respectively.
x = get_inputs() #returns Tensor of shape [S,B,N]
laplacian_per_batch = torch.zeros(x.shape[0:2]) #array to store laplacian for each batch
for j in range(x.shape[1]): #cycle over each batch
    xj = x[:,j,:] #input of batch j
    yj = model(xj) #output of batch j
    laplacianj = torch.zeros(xj.shape[0]) #laplacian array for batch j
    for i, xi in enumerate(xj): #cycle over each element within batch j
        hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0), create_graph=True) #calculate
        laplacianj[i] = torch.diagonal(hess.view(N, N), offset=0).sum() #laplacian
    laplacian_per_batch[:,j] = laplacianj #store laplacian of batch j in total array
I know from previous questions that batch-vectorized Hessian computation is currently a work in progress, but I was wondering whether it's possible to run each "batch loop" (the outer loop over j) in parallel. A single pass through the inner loop over i takes around 10 s, and since I usually need anywhere from 8 to 32 batches, the runtime of a single pass becomes somewhat impractical.
I assume this is possible since the individual batches aren't connected to one another. In theory I could reshape x to [S*B, N], but that would just collapse the two nested loops into a single equally sized loop, rather than processing each batch in parallel.
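To make the reshape idea concrete, this is the equivalent single-loop version I have in mind (again with a toy stand-in network for illustration). It produces the same values as the double loop, just without any actual parallelism:

```python
import torch

S, B, N = 3, 2, 4  # small sizes just for illustration

# toy stand-in for the real R^N -> R^1 model
model_net = torch.nn.Sequential(
    torch.nn.Linear(N, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)

def model(x):
    # x: [1, N] -> scalar output
    return model_net(x).squeeze()

x = torch.randn(S, B, N)
x_flat = x.reshape(S * B, N)  # merge the sample and batch dimensions

laplacian_flat = torch.zeros(S * B)
for i, xi in enumerate(x_flat):  # one loop of length S*B instead of two nested loops
    hess = torch.autograd.functional.hessian(model, xi.unsqueeze(0))
    laplacian_flat[i] = torch.diagonal(hess.view(N, N), offset=0).sum()

# undo the merge to recover the per-batch layout [S, B]
laplacian_per_batch = laplacian_flat.reshape(S, B)
```
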
Thank you in advance!