Applying a function to each individual element of a tensor

I have a task where I perform an operation over individual elements of the network gradient.

For a fully-connected layer, I have an m x n gradient matrix G and a corresponding m x n update matrix P.

The elements of G and P map one-to-one, i.e. P_{ij} dictates how G_{ij} should be updated.

I have gotten this to work with a nested loop, but that was extremely slow. I improved the speed somewhat with multiprocessing's starmap method (roughly what that looked like is sketched after the example below), but it is still too slow, and I would like a better, more direct way of doing this.

Is there a PyTorch way to perform this operation in parallel? I understand that this should be possible, but I have not been able to figure it out. I have also tried PyTorch's vmap function without success.

Here's a reproducible example of what I want to achieve.

import torch


def update_gradient(g, p):
    dummy = 1

    if p > 1:
        dummy += 100
    else:
        dummy -= 9

    return dummy


m, n = 10, 15

gradient = torch.rand(m, n)
pulses = torch.rand(m, n)

new_list = []


# instead of this nested for loop - I want to map my update_gradient method in parallel to each element of the gradient and pulse matrix
for i in range(0, gradient.shape[0]):
    for j in range(0, gradient.shape[1]):
        new_list.append(update_gradient(gradient[i][j], pulses[i][j]))


new_list = torch.Tensor(new_list)
new_list = torch.reshape(new_list, gradient.shape)
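
For reference, here is roughly what my starmap attempt looked like. This is just a rough sketch of the idea, not my exact code:

import torch
from multiprocessing import Pool


def update_gradient(g, p):
    # same toy per-element rule as in the example above
    dummy = 1

    if p > 1:
        dummy += 100
    else:
        dummy -= 9

    return dummy


if __name__ == "__main__":
    m, n = 10, 15

    gradient = torch.rand(m, n)
    pulses = torch.rand(m, n)

    # one (g, p) pair of Python floats per matrix element
    pairs = zip(gradient.flatten().tolist(), pulses.flatten().tolist())

    # farm the per-element calls out to a pool of worker processes
    with Pool() as pool:
        results = pool.starmap(update_gradient, pairs)

    new_tensor = torch.tensor(results).reshape(gradient.shape)

This helps a little, but the per-element Python calls and the process overhead still dominate, which is why I am looking for a tensor-level solution.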

Hi Osama!

In general, if you want to apply a function element-wise to the elements
of a pytorch tensor and that function is built up of “straightforward” pieces,
it will usually be possible to rewrite that function in terms of pytorch tensor
operations that work on the tensor as a whole (element-wise) without using
loops.

Here’s an illustration that compares your example code with a loop-free
version:

>>> import torch
>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> def update_gradient(g, p):
...     dummy = 1
...
...     if p > 1:
...         dummy += 100
...     else:
...         dummy -= 9
...
...     return dummy
...
>>> m, n = 10, 15
>>>
>>> gradient = torch.rand(m, n)
>>> pulses = torch.rand(m, n)
>>>
>>> new_list = []
>>>
>>> # instead of this nested for loop - I want to map my update_gradient method in parallel to each element of the gradient and pulse matrix
... for i in range(0, gradient.shape[0]):
...     for j in range(0, gradient.shape[1]):
...         new_list.append(update_gradient(gradient[i][j], pulses[i][j]))
...
>>> new_list = torch.Tensor(new_list)
>>> new_list = torch.reshape(new_list, gradient.shape)
>>>
>>> # loop-free, pytorch-tensor-function version
... def update_gradientB (g, p):
...     dummy = torch.ones_like (g)
...     dummy[p  > 1] += 100
...     dummy[p <= 1] -= 9
...     return dummy
...
>>> resultB = update_gradientB (gradient, pulses)
>>>
>>> resultB
tensor([[-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.],
        [-8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8., -8.,
         -8.]])
>>>
>>> torch.equal (resultB, new_list)
True
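
As a side note, if your real update rule is also a simple per-element
threshold like this one, torch.where gives the same result without the
two index-assignment steps. (This is just a sketch, not part of the
session above.)

def update_gradientC(g, p):
    # start from ones, then select 1 + 100 where p > 1 and 1 - 9 elsewhere
    dummy = torch.ones_like(g)
    return torch.where(p > 1, dummy + 100, dummy - 9)

Note that torch.where evaluates both branch tensors everywhere and then
selects per element, so it fits best when each branch is a cheap tensor
expression.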

Best.

K. Frank


I understand your solution, but my problem is that the function I actually need isn't really expressible in terms of PyTorch tensor operations (at least as far as I have been able to work out).

I have asked a follow-up question here that is specific to my problem. Thanks for your input, @KFrank!