# Apply tensor-valued function rowwise to get another tensor

I have got a tensor of size `(N,2)`

``````
import torch

N = 10
t = torch.randn(N, 2)
``````

and a function `f` that takes two values `x` and `y` and returns a (2,2) tensor, say

``````
def f(x, y):
    return torch.tensor([[x, x + y], [x**2, y**2]])
``````

How can I apply the function to each of the `N` rows of the first tensor and store the result as a `(N,2,2)` tensor? Thank you.

I’m unsure if I understand your use case correctly, but I assume you want to evaluate `f` with `x=t[i, 0]` and `y=t[i, 1]` for every “row” `i` of `t`.
If so, you could use `torch.vmap` and rewrite the function a bit, as `vmap` will raise an error if a new tensor is created from the inputs inside the function (e.g. via `torch.tensor`).
This might work:

``````
def f(x):
    # x is a single row of t with shape (2,)
    return torch.stack((
        torch.stack((x[0], x[0] + x[1])),
        torch.stack((x[0]**2, x[1]**2))),
        dim=0)

batched_f = torch.vmap(f)
batched_f(t)
# tensor([[[-1.8799e+00, -2.0183e+00],
#          [ 3.5340e+00,  1.9150e-02]],

#         [[-3.4007e-01, -2.0362e-01],
#          [ 1.1565e-01,  1.8621e-02]],

#         [[ 1.0567e-01,  1.6185e-01],
#          [ 1.1167e-02,  3.1562e-03]],

#         [[ 4.5812e-01,  4.7626e-01],
#          [ 2.0988e-01,  3.2880e-04]],

#         [[ 7.6334e-01,  7.1346e-01],
#          [ 5.8269e-01,  2.4886e-03]],

#         [[-1.9683e-01,  1.4371e+00],
#          [ 3.8741e-02,  2.6697e+00]],

#         [[ 1.1249e+00, -5.1169e-01],
#          [ 1.2654e+00,  2.6784e+00]],

#         [[-3.8157e-02, -1.0671e-01],
#          [ 1.4559e-03,  4.6996e-03]],

#         [[ 1.2200e+00,  2.6654e+00],
#          [ 1.4884e+00,  2.0891e+00]],

#         [[ 7.1752e-01, -1.5422e-01],
#          [ 5.1483e-01,  7.5992e-01]]])

# Compare with a plain Python loop over the rows
for t_ in t:
    print(f(t_))
# tensor([[-1.8799, -2.0183],
#         [ 3.5340,  0.0192]])
# tensor([[-0.3401, -0.2036],
#         [ 0.1157,  0.0186]])
# tensor([[0.1057, 0.1619],
#         [0.0112, 0.0032]])
# tensor([[4.5812e-01, 4.7626e-01],
#         [2.0988e-01, 3.2880e-04]])
# tensor([[0.7633, 0.7135],
#         [0.5827, 0.0025]])
# tensor([[-0.1968,  1.4371],
#         [ 0.0387,  2.6697]])
# tensor([[ 1.1249, -0.5117],
#         [ 1.2654,  2.6784]])
# tensor([[-0.0382, -0.1067],
#         [ 0.0015,  0.0047]])
# tensor([[1.2200, 2.6654],
#         [1.4884, 2.0891]])
# tensor([[ 0.7175, -0.1542],
#         [ 0.5148,  0.7599]])
``````
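
If you would rather keep the two-argument signature from the question, something along these lines might also work. This is only a sketch: it assumes `f` can be written with `torch.stack` as above, and the names `batched`, `looped` and `vectorized` are made up for the comparison. `torch.vmap` maps over dim 0 of each positional argument by default, so the two columns of `t` can be passed separately.

```python
import torch

def f(x, y):
    # Same entries as the single-argument version above, but with x and y
    # passed separately. torch.stack builds the result from the inputs,
    # which keeps the function usable under torch.vmap.
    return torch.stack((
        torch.stack((x, x + y)),
        torch.stack((x**2, y**2)),
    ), dim=0)

torch.manual_seed(0)
N = 10
t = torch.randn(N, 2)

# vmap batches over dim 0 of both arguments (in_dims=0 is the default).
batched = torch.vmap(f)(t[:, 0], t[:, 1])

# Reference result: plain Python loop over the rows.
looped = torch.stack([f(row[0], row[1]) for row in t])

# Alternative without vmap: apply the same stack calls to whole columns.
x, y = t.unbind(dim=1)
vectorized = torch.stack((
    torch.stack((x, x + y), dim=-1),
    torch.stack((x**2, y**2), dim=-1),
), dim=-2)

print(batched.shape)
print(torch.allclose(batched, looped), torch.allclose(vectorized, looped))
```

All three results should have shape `(N, 2, 2)` and agree with each other. For a function this simple the column-wise version is probably enough. `vmap` is more useful when `f` is harder to vectorize by hand.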