Assuming your tensors are uint8 (images with values 0-255 inclusive), it's much faster to convert to NumPy, stack in NumPy, then convert back to torch.
Random tensors (note that torch.randint's high bound is exclusive, so high=256 gives values 0-255):

import numpy as np
import torch

tensor_list = [torch.randint(size=(3, 120, 1024), low=0, high=256, dtype=torch.uint8) for _ in range(4096)]
Stacking in torch:
def stack_via_torch1(tensor_list):
    return torch.stack(tensor_list)

def stack_via_torch2(tensor_list):
    stacked_ten = torch.empty((len(tensor_list), *tensor_list[0].shape), dtype=torch.uint8)
    for i, ten in enumerate(tensor_list):
        stacked_ten[i] = ten
    return stacked_ten
%timeit stack_via_torch1(tensor_list)
946 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit stack_via_torch2(tensor_list)
972 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Stacking via NumPy (~1.8x faster for 4096 tensors):
def stack_via_numpy1(tensor_list):
    return torch.from_numpy(np.array([ten.numpy() for ten in tensor_list]))

def stack_via_numpy2(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten
    return torch.from_numpy(np.stack(np_data))

def stack_via_numpy3(tensor_list):
    np_data: np.ndarray = np.empty((len(tensor_list), *tensor_list[0].shape), dtype=np.uint8)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten
    return torch.from_numpy(np_data)

def stack_via_numpy4(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten.numpy()
    return torch.from_numpy(np.stack(np_data))

def stack_via_numpy5(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten
    np_empty = np.empty((len(tensor_list), *tensor_list[0].shape), dtype=np.uint8)
    return torch.from_numpy(np.stack(np_data, out=np_empty))
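As a quick sanity check (my addition, not part of the original benchmark), the NumPy route can be verified to produce a result identical to torch.stack, using a small list so it runs instantly:

```python
import numpy as np
import torch

# Small list for a quick equivalence check (shapes chosen arbitrarily)
tensor_list = [torch.randint(size=(3, 8, 16), low=0, high=256, dtype=torch.uint8) for _ in range(4)]

def stack_via_numpy1(tensor_list):
    return torch.from_numpy(np.array([ten.numpy() for ten in tensor_list]))

reference = torch.stack(tensor_list)
result = stack_via_numpy1(tensor_list)

assert result.shape == reference.shape   # (4, 3, 8, 16)
assert result.dtype == torch.uint8
assert torch.equal(result, reference)    # bit-identical to torch.stack
```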
%timeit stack_via_numpy1(tensor_list)
526 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit stack_via_numpy2(tensor_list)
533 ms ± 2.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit stack_via_numpy3(tensor_list)
533 ms ± 2.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit stack_via_numpy4(tensor_list)
528 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit stack_via_numpy5(tensor_list)
534 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Thus stacking via NumPy is ~1.8x faster for 4096 tensors.
Additional notes:
- I prefer stack_via_numpy1 since it's the fastest and the easiest to remember.
- I observed a similar speedup across list sizes from 1 up to ~65k tensors. The one exception was the range of 64 to 256 tensors, where, oddly, torch and NumPy took about the same time.
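One caveat worth keeping in mind (my addition, not from the benchmark above): torch.from_numpy shares memory with the source NumPy array rather than copying it. In the functions above the array is a fresh temporary, so this is harmless, but if you keep a reference to the array, mutating it also mutates the tensor. A minimal illustration:

```python
import numpy as np
import torch

arr = np.zeros((2, 2), dtype=np.uint8)
ten = torch.from_numpy(arr)  # zero-copy: tensor and array share one buffer

arr[0, 0] = 7
print(ten[0, 0].item())  # → 7: the write to the array is visible through the tensor
```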