I have a list of tensors, each with a different shape, and I want to apply the same function to each tensor. Since the shapes differ, I can't use torch.vmap. Is there a way to do this without a for-loop and without padding every tensor to the same shape?
You could use the _foreach_* methods for your use case; they are used internally by, e.g., the optimizers:
import torch
device = "cuda"
x = [torch.randn(i, device=device) for i in range(1, 10)]
print(x)
# [tensor([1.8445], device='cuda:0'), tensor([0.3656, 1.9296], device='cuda:0'), tensor([ 0.0386, -0.3014, -2.0722], device='cuda:0'), tensor([-2.3759, 1.0400, -0.1989, -1.2025], device='cuda:0'), tensor([-0.5833, 0.6642, -0.4505, 1.3151, 0.2400], device='cuda:0'), tensor([ 0.4314, -0.3071, -1.4707, 0.4154, -1.9433, -0.9677], device='cuda:0'), tensor([-0.5343, 0.5911, 0.7805, -0.2972, 1.7352, 0.1670, 0.4759],
# device='cuda:0'), tensor([ 0.8740, -0.8012, 1.8532, -0.7882, -1.6124, 2.2543, 0.4694, -2.3202],
# device='cuda:0'), tensor([-0.2593, -0.3326, 0.4430, -0.0338, 1.1399, 0.1994, 1.6222, 0.8456,
# -0.1709], device='cuda:0')]
y = torch._foreach_abs(x)
print(y)
# (tensor([1.8445], device='cuda:0'), tensor([0.3656, 1.9296], device='cuda:0'), tensor([0.0386, 0.3014, 2.0722], device='cuda:0'), tensor([2.3759, 1.0400, 0.1989, 1.2025], device='cuda:0'), tensor([0.5833, 0.6642, 0.4505, 1.3151, 0.2400], device='cuda:0'), tensor([0.4314, 0.3071, 1.4707, 0.4154, 1.9433, 0.9677], device='cuda:0'), tensor([0.5343, 0.5911, 0.7805, 0.2972, 1.7352, 0.1670, 0.4759],
# device='cuda:0'), tensor([0.8740, 0.8012, 1.8532, 0.7882, 1.6124, 2.2543, 0.4694, 2.3202],
# device='cuda:0'), tensor([0.2593, 0.3326, 0.4430, 0.0338, 1.1399, 0.1994, 1.6222, 0.8456, 0.1709],
# device='cuda:0'))
Note that these methods are internal and their API could change.
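A minimal sketch of chaining a few more _foreach_* calls, plus an in-place list-with-list update similar in spirit to an SGD step. It assumes a PyTorch version where torch._foreach_abs, torch._foreach_mul, torch._foreach_add and torch._foreach_add_ exist (they are private, so this may break across releases), and it falls back to CPU when no GPU is present. The names params, grads and lr are hypothetical toy values, not an actual optimizer's internals.

```python
import torch

# Assumption: fall back to CPU if CUDA is unavailable; the _foreach_* ops
# also run on CPU (typically via a per-tensor slow path).
device = "cuda" if torch.cuda.is_available() else "cpu"
x = [torch.randn(i, device=device) for i in range(1, 10)]

# Chain elementwise ops over the whole list without an explicit Python loop.
y = torch._foreach_abs(x)        # |x_i|
y = torch._foreach_mul(y, 2.0)   # 2 * |x_i|
y = torch._foreach_add(y, 1.0)   # 2 * |x_i| + 1

# Check against a plain loop; shapes are preserved per tensor.
for a, t in zip(y, x):
    assert a.shape == t.shape
    assert torch.allclose(a, 2.0 * t.abs() + 1.0)

# Hypothetical optimizer-style update: params_i <- params_i - lr * grads_i,
# applied in place to every (param, grad) pair at once.
params = [torch.ones(i, device=device) for i in range(1, 10)]
grads = [torch.full((i,), 0.5, device=device) for i in range(1, 10)]
lr = 0.1
torch._foreach_add_(params, grads, alpha=-lr)
assert all(torch.allclose(p, torch.full_like(p, 0.95)) for p in params)
```

The in-place variants (trailing underscore) avoid allocating a new list of outputs, which is why optimizer-style code tends to use them.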
To add to piotr's answer: if your tensors differ in shape along only a single dimension, you can try NestedTensors (see: How to apply vmap on a heterogeneous tensor - #2 by soulitzer).
Internally, nested tensors represent ragged data using a packed representation. For certain operators you might get better coverage than with the foreach methods.