I have a task where I perform an operation over the individual elements of a network's gradient. For a fully-connected layer, I have an m x n gradient matrix G and a corresponding m x n update matrix P. Elements of G and P map one-to-one, i.e. P_{ij} dictates how to update G_{ij}.
I have gotten this to work with a nested loop, but that was extremely slow. I managed to improve the speed with multiprocessing's starmap method, but I want a more direct approach, as this is still slow.
Is there a PyTorch way to perform this operation in parallel? I understand that this should be possible, but I have not been able to figure it out. I have also tried PyTorch's vmap function with no success (see my attempt after the example below).
Here's a reproducible example of what I want to achieve:
import torch

def update_gradient(g, p):
    # Toy stand-in for my real element-wise update rule: the result
    # depends on a data-dependent branch on the pulse value p.
    dummy = 1
    if p > 1:
        dummy += 100
    else:
        dummy -= 9
    return dummy

m, n = 10, 15
gradient = torch.rand(m, n)
pulses = torch.rand(m, n)

# Instead of this nested for loop, I want to map update_gradient in
# parallel over each element pair of the gradient and pulse matrices.
new_list = []
for i in range(gradient.shape[0]):
    for j in range(gradient.shape[1]):
        new_list.append(update_gradient(gradient[i, j], pulses[i, j]))
new_list = torch.tensor(new_list, dtype=torch.float32)
new_list = torch.reshape(new_list, gradient.shape)
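
For reference, here is roughly what my vmap attempt looked like. This is only a sketch, assuming torch.func.vmap from PyTorch 2.x (functorch.vmap on older versions); batched_update and expected are my own names:

from torch.func import vmap

# Map over rows, then over columns, so that update_gradient
# receives one scalar pair (g, p) at a time.
batched_update = vmap(vmap(update_gradient))
try:
    new_matrix = batched_update(gradient, pulses)
except RuntimeError as e:
    # Fails: vmap cannot evaluate the data-dependent branch
    # `if p > 1:` on a batched tensor.
    print(e)

# For this particular toy rule I know the expected output can be
# built directly with torch.where (1 + 100 = 101 where p > 1,
# 1 - 9 = -8 otherwise):
expected = torch.where(pulses > 1,
                       torch.full_like(pulses, 101.0),
                       torch.full_like(pulses, -8.0))

What I'm after, though, is a general way to apply an arbitrary element-wise function like update_gradient to two tensors in parallel, not a hand-vectorized rewrite of this one case.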