I’m unsure if I understand your use case correctly, but I assume you want to get the results for `x = t[0, 0]` and `y = t[0, 1]` for all “rows” of `t`.
If so, you could use `torch.vmap` and rewrite the function a bit, since `torch.vmap` will raise an error if a new tensor is created inside the function.
Something like this might work:
def f(x):
    """Build a 2x2 tensor from the first two entries of ``x``.

    The result is laid out as::

        [[x[0],      x[0] + x[1]],
         [x[0] ** 2, x[1] ** 2  ]]

    Args:
        x: Tensor indexed as ``x[0]`` and ``x[1]``; when called through
            ``torch.vmap`` this is a single "row" of the batched input.

    Returns:
        The two rows above stacked along ``dim=0``.
    """
    # Assemble the output with torch.stack on the indexed inputs rather than
    # creating a fresh tensor, so the function can be wrapped by torch.vmap.
    top_row = torch.stack((x[0], x[0] + x[1]))
    bottom_row = torch.stack((x[0] ** 2, x[1] ** 2))
    return torch.stack((top_row, bottom_row), dim=0)
# Wrap f with torch.vmap so it is applied to every "row" of t in one call
# instead of an explicit Python loop.
batched_f = torch.vmap(f)
# t is defined outside this snippet; the commented tensor below is the
# result of this call (one 2x2 block per row of t).
batched_f(t)
# tensor([[[-1.8799e+00, -2.0183e+00],
# [ 3.5340e+00, 1.9150e-02]],
# [[-3.4007e-01, -2.0362e-01],
# [ 1.1565e-01, 1.8621e-02]],
# [[ 1.0567e-01, 1.6185e-01],
# [ 1.1167e-02, 3.1562e-03]],
# [[ 4.5812e-01, 4.7626e-01],
# [ 2.0988e-01, 3.2880e-04]],
# [[ 7.6334e-01, 7.1346e-01],
# [ 5.8269e-01, 2.4886e-03]],
# [[-1.9683e-01, 1.4371e+00],
# [ 3.8741e-02, 2.6697e+00]],
# [[ 1.1249e+00, -5.1169e-01],
# [ 1.2654e+00, 2.6784e+00]],
# [[-3.8157e-02, -1.0671e-01],
# [ 1.4559e-03, 4.6996e-03]],
# [[ 1.2200e+00, 2.6654e+00],
# [ 1.4884e+00, 2.0891e+00]],
# [[ 7.1752e-01, -1.5422e-01],
# [ 5.1483e-01, 7.5992e-01]]])
# Reference check: apply f to each row of t in a plain Python loop. The
# printed 2x2 tensors should match the corresponding blocks of batched_f(t).
for t_ in t:
    print(f(t_))
# tensor([[-1.8799, -2.0183],
# [ 3.5340, 0.0192]])
# tensor([[-0.3401, -0.2036],
# [ 0.1157, 0.0186]])
# tensor([[0.1057, 0.1619],
# [0.0112, 0.0032]])
# tensor([[4.5812e-01, 4.7626e-01],
# [2.0988e-01, 3.2880e-04]])
# tensor([[0.7633, 0.7135],
# [0.5827, 0.0025]])
# tensor([[-0.1968, 1.4371],
# [ 0.0387, 2.6697]])
# tensor([[ 1.1249, -0.5117],
# [ 1.2654, 2.6784]])
# tensor([[-0.0382, -0.1067],
# [ 0.0015, 0.0047]])
# tensor([[1.2200, 2.6654],
# [1.4884, 2.0891]])
# tensor([[ 0.7175, -0.1542],
# [ 0.5148, 0.7599]])