Best way to merge dimensions in PyTorch

I am reading through the paper “Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting”.

As part of their architecture, they merge adjacent segments of the time series as it passes through the encoder, so that deeper layers attend to coarser time scales. In their GitHub repo, they do it in the following manner:

# [Dummy Input]
import torch

win_size = 2
x = torch.arange(40).reshape(1, 2, 10, 2)  # (batch, series, segments, features)

# Collect every win_size-th segment at each offset i, then
# concatenate the slices along the feature dimension
seg_to_merge = []
for i in range(win_size):
    seg_to_merge.append(x[:, :, i::win_size, :])
x = torch.cat(seg_to_merge, -1)  # shape: (1, 2, 5, 4)

However, this gives the same output as this line:

x = x.reshape(1, 2, 5, 4)
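
To double-check, here is a quick equality test on the dummy input above (this assumes the segment count is divisible by win_size, as it is here):

import torch

win_size = 2
x = torch.arange(40).reshape(1, 2, 10, 2)

# Loop-and-cat version from the repo
merged_cat = torch.cat([x[:, :, i::win_size, :] for i in range(win_size)], -1)

# Single reshape: fold each group of win_size adjacent segments into the feature dim
b, d, seg_num, c = x.shape
merged_reshape = x.reshape(b, d, seg_num // win_size, win_size * c)

print(torch.equal(merged_cat, merged_reshape))  # prints True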

Is there a difference between the two? I try to keep things as simple as possible, and the reshape just seems a lot cleaner. The only thing I can think of is that the reshape has some effect on backprop, or something else behind the scenes. If anyone has a better idea for the title of this question, I’m also open to changing it; I’m just not sure what a good one would be.

I’m also unsure why the loop with cat is used in the first place, since reshape does not break the computation graph and should be fine to use.
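
For what it’s worth, a minimal autograd sketch (reusing the dummy shapes above; the weight tensor w is just an arbitrary loss weighting I made up) suggests both paths produce identical gradients:

import torch

win_size = 2
x = torch.randn(1, 2, 10, 2, requires_grad=True)
w = torch.randn(1, 2, 5, 4)  # arbitrary weights so the loss depends on position

# Backprop through the loop-and-cat path
y_cat = torch.cat([x[:, :, i::win_size, :] for i in range(win_size)], -1)
g_cat, = torch.autograd.grad((y_cat * w).sum(), x)

# Backprop through the reshape path
y_reshape = x.reshape(1, 2, 5, 4)
g_reshape, = torch.autograd.grad((y_reshape * w).sum(), x)

print(torch.equal(g_cat, g_reshape))  # prints True

As far as I can tell, the only real difference is that reshape on a contiguous tensor returns a view while cat always allocates a new tensor, so if anything the reshape should be cheaper.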
