As far as I am aware, PyTorch does not have this kind of “map” function.
However, PyTorch supports many functions that act element-wise on
tensors (arithmetic, cos(), log(), etc.). If you can rewrite your
function using only element-wise torch tensor operations, the
composite function will also act element-wise and will do what
you want.
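For example, a function built only from element-wise torch operations can be called once on a whole tensor, with no explicit map. The function `f` below is a hypothetical example chosen for illustration. The per-element loop is only there to check that the single call gives the same result as applying `f` to each element separately.

```python
import torch

def f(x):
    # Hypothetical function made only of element-wise torch ops,
    # so it behaves the same on a 0-d tensor or a tensor of any shape.
    return torch.cos(x) * torch.log1p(x * x) + 2 * x

x = torch.linspace(-2.0, 2.0, steps=9)

# One call on the whole tensor: every op is applied element-wise.
y_vectorized = f(x)

# Reference: apply f to each element separately, as map() would.
y_loop = torch.stack([f(xi) for xi in x])

print(torch.allclose(y_vectorized, y_loop))  # prints True
```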
NumPy provides a way to vectorize a function (np.vectorize), and the examples for it make it very clear and easy to understand. I have not been able to find a similar thing in PyTorch. A pointer to either of the following would be really helpful:
How to use map() with PyTorch tensors?
Is there any API like np.vectorize?
PS: We want to apply a function f to each element of a list of tensors.
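Python's built-in map() already works on a list of tensors, because it accepts any iterable. Below is a minimal sketch. The per-tensor function `f` is hypothetical. The batched version assumes every tensor in the list has the same shape, which is what `torch.stack` requires.

```python
import torch

def f(t):
    # Hypothetical per-tensor function: sum of squares of the entries.
    return t.pow(2).sum()

tensors = [torch.randn(4) for _ in range(3)]

# Built-in map() over a list of tensors returns a lazy iterator; list() collects it.
results = list(map(f, tensors))           # list of 0-d tensors

# A list comprehension is the more common spelling of the same loop.
results_lc = [f(t) for t in tensors]

# If all tensors share a shape, stack them and use one batched op instead of a loop.
batched = torch.stack(tensors).pow(2).sum(dim=1)

print(torch.allclose(torch.stack(results), batched))  # prints True
```

The loop is fine for short lists. For large batches, the stacked version avoids per-tensor Python overhead.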
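With functorch, you can write the function for a single row and let vmap map it over the batch dimension:
import torch
import functorch
batch_size, feature_size = 3, 5
v = torch.randn(batch_size, feature_size)
def simple_row_func(feature_vec):
    # remove mean (feature_vec is a single 1-D row)
    return feature_vec - feature_vec.mean()
result = functorch.vmap(simple_row_func)(v)
# equivalent to
result = v - v.mean(1, keepdim=True)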
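In PyTorch 2.0 and later, functorch's vmap is also available as `torch.func.vmap`. The sketch below assumes PyTorch >= 2.0. It checks that the vmap result matches the hand-written batched expression. It also shows `in_dims`/`out_dims`, which let you map over columns instead of rows.

```python
import torch
from torch.func import vmap  # assumes PyTorch >= 2.0

def simple_row_func(feature_vec):
    # Written for a single 1-D slice; vmap supplies the batch dimension.
    return feature_vec - feature_vec.mean()

batch_size, feature_size = 3, 5
v = torch.randn(batch_size, feature_size)

result_vmap = vmap(simple_row_func)(v)          # maps over dim 0 (rows) by default
result_manual = v - v.mean(1, keepdim=True)     # hand-written batched version
print(torch.allclose(result_vmap, result_manual))  # prints True

# in_dims=1 maps over columns; out_dims=1 puts the mapped dim back at position 1.
col_centered = vmap(simple_row_func, in_dims=1, out_dims=1)(v)
print(torch.allclose(col_centered, v - v.mean(0, keepdim=True)))  # prints True
```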