How can I calculate the time complexity (big-O) of torch.cat and torch.stack?
Thank you.
Both allocate an output tensor of the required size and then copy the values into it, so the cost is proportional to the total number of elements. In degenerate cases (e.g. many 0-sized tensors) there is also a per-operand overhead that scales with the number of inputs; in non-degenerate cases this is bounded by the total size.
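A minimal sketch of the allocate-then-copy behavior described above, restricted to 1-D inputs joined along dim 0. The helper name `manual_cat` is hypothetical, and this is not PyTorch's actual implementation (which is written in C++ and handles arbitrary dimensions, dtypes and devices). It only illustrates why the work is one allocation plus a copy of every input element.

```python
import torch

def manual_cat(tensors):
    """Hypothetical illustration of concatenating 1-D tensors along dim 0.

    Not PyTorch's real implementation: it allocates the output once and
    copies each input into its slice, so the work is proportional to the
    total number of elements plus a small per-operand loop overhead.
    """
    total = sum(t.numel() for t in tensors)           # size of the target blob
    out = torch.empty(total, dtype=tensors[0].dtype)  # a single allocation
    offset = 0
    for t in tensors:                                 # one step per operand
        n = t.numel()
        out[offset:offset + n].copy_(t)               # copy n elements
        offset += n
    return out

T1 = torch.tensor([1., 2., 3., 4.])
T2 = torch.tensor([0., 3., 4., 1.])
T3 = torch.tensor([4., 3., 2., 5.])
result = manual_cat((T1, T2, T3))
assert torch.equal(result, torch.cat((T1, T2, T3)))
```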
Strictly speaking, the number of dimensions also plays a role, but in terms of asymptotic analysis you would soon discover that there is a hard upper bound for the number of dimensions you can use with PyTorch anyway.
The most common pitfall to avoid is growing a tensor by repeated concatenation in a loop, e.g. growing = torch.cat((growing, new_item)). The i-th iteration copies a tensor whose size is proportional to i, so the total work is quadratic in the number of iterations, which becomes noticeable quite fast.
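A sketch of the pitfall and the usual fix, assuming the pieces arrive one at a time in a loop: collect them in a Python list and call `torch.cat` once at the end, so every element is copied only once. The names `make_item`, `grow_by_repeated_cat` and `collect_then_cat` are hypothetical.

```python
import torch

def make_item(i):
    # Hypothetical producer of one fixed-size piece per iteration.
    return torch.full((4,), float(i))

def grow_by_repeated_cat(iters):
    # Anti-pattern: every iteration copies the whole accumulated tensor again,
    # so the total work grows like 1 + 2 + ... + iters, i.e. quadratically.
    growing = torch.empty(0)
    for i in range(iters):
        growing = torch.cat((growing, make_item(i)))
    return growing

def collect_then_cat(iters):
    # Preferred: appending to a list is cheap, and a single cat copies
    # each element exactly once (linear in the total size).
    # Note: torch.cat raises an error on an empty list, so iters must be > 0.
    pieces = [make_item(i) for i in range(iters)]
    return torch.cat(pieces)
```

Both functions return the same tensor; only the amount of copying differs.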
Best regards
Thomas
To clarify, I want to calculate the time complexity of torch.cat and torch.stack for the following examples.
Example 1, concatenation with torch.cat:
import torch
T1 = torch.tensor([1., 2., 3., 4.])
T2 = torch.tensor([0., 3., 4., 1.])
T3 = torch.tensor([4., 3., 2., 5.])
T = torch.cat((T1, T2, T3))  # shape (12,): the three tensors joined end to end
Example 2, stacking with torch.stack:
import torch
T1 = torch.tensor([1., 2., 3., 4.])
T2 = torch.tensor([0., 3., 4., 1.])
T3 = torch.tensor([4., 3., 2., 5.])
T = torch.stack((T1, T2, T3))  # shape (3, 4): inputs stacked along a new leading dimension
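For the two examples above, a quick check (a sketch, not a benchmark) that both calls produce outputs with the same 12 values; only the output shape differs. `torch.stack` along dim 0 gives the same result as concatenating the inputs after `unsqueeze(0)`, which is why its cost is also proportional to the total size.

```python
import torch

T1 = torch.tensor([1., 2., 3., 4.])
T2 = torch.tensor([0., 3., 4., 1.])
T3 = torch.tensor([4., 3., 2., 5.])

T_cat = torch.cat((T1, T2, T3))      # shape (12,)
T_stack = torch.stack((T1, T2, T3))  # shape (3, 4), new dim 0

# Both outputs hold 3 * 4 = 12 elements, so both calls copy 12 values.
assert T_cat.numel() == T_stack.numel() == 12

# stack along dim 0 matches cat of the inputs with a new leading dimension.
assert torch.equal(T_stack, torch.cat([t.unsqueeze(0) for t in (T1, T2, T3)]))

# Flattening the stacked result gives the concatenated one.
assert torch.equal(T_stack.reshape(-1), T_cat)
```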
I’m not sure what you are asking here.
If your tensors have fixed sizes, the operation takes a fixed amount of time, which you could measure, e.g. by profiling.
When you have many or large tensors, the time scales roughly as described above: proportionally to the total number of elements.
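A rough timing sketch with `time.perf_counter`, assuming CPU tensors and a fixed number of operands. It is not a rigorous benchmark (no warm-up, and results depend on hardware, threading and memory allocation), but for large sizes the time per call should grow roughly linearly with the total number of elements. The helper `time_cat` is hypothetical; `torch.utils.benchmark` or the PyTorch profiler would give more reliable numbers.

```python
import time
import torch

def time_cat(num_tensors, size, repeats=10):
    # Average wall-clock time of one torch.cat over `num_tensors` CPU inputs
    # of `size` elements each (GPU timing would also need synchronization).
    tensors = [torch.randn(size) for _ in range(num_tensors)]
    start = time.perf_counter()
    for _ in range(repeats):
        torch.cat(tensors)
    return (time.perf_counter() - start) / repeats

for size in (10_000, 100_000, 1_000_000):
    t = time_cat(num_tensors=8, size=size)
    print(f"total elements: {8 * size:>9,}  time per cat: {t * 1e3:.3f} ms")
```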
Best regards
Thomas
No, my tensors are not fixed. At each iteration, I have a number of tensors that I concatenate with either torch.cat or torch.stack. I don't understand your explanation above. Is the time complexity O(n) or O(n^2)?
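Each single torch.cat or torch.stack call costs time proportional to the total size of the tensors involved, i.e. O(n) in the number of elements n.
If you have a loop that concatenates fixed-size tensors in each iteration, the total cost is iters * fixed, which is linear in the number of iterations.
If you have a loop that grows a tensor linearly by repeated concatenation, you have something like fixed + 2 * fixed + 3 * fixed + … + iters * fixed = fixed * iters * (iters + 1) / 2, which is quadratic in the number of iterations, i.e. O(iters^2).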
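The arithmetic above can be checked by counting how many elements each loop writes, without timing anything. This sketch assumes that one `torch.cat` copies exactly as many elements as its output holds (the cost model described above); `fixed` and `iters` are the names from the formulas, and the counter functions are hypothetical.

```python
import torch

def copies_fixed_size_loop(iters, fixed):
    # Each iteration concatenates two pieces totalling `fixed` elements
    # into a fresh output, so it copies `fixed` elements.
    copied = 0
    for _ in range(iters):
        out = torch.cat((torch.zeros(fixed // 2), torch.zeros(fixed - fixed // 2)))
        copied += out.numel()
    return copied

def copies_growing_loop(iters, fixed):
    # Iteration k (1-based) re-copies the accumulated tensor plus one new
    # piece, producing an output of k * fixed elements.
    growing = torch.empty(0)
    copied = 0
    for _ in range(iters):
        growing = torch.cat((growing, torch.zeros(fixed)))
        copied += growing.numel()
    return copied

iters, fixed = 100, 16
assert copies_fixed_size_loop(iters, fixed) == iters * fixed
assert copies_growing_loop(iters, fixed) == fixed * iters * (iters + 1) // 2
```

Doubling `iters` doubles the first count but roughly quadruples the second, which is the linear versus quadratic difference.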
Thank you for these useful explanations!