I have data of shape [B, T, N, F], where the letters stand for batch, time, node and features respectively. I pass this tensor through several modules and permute its dimensions along the way, so I will try to explain each step as best I can.
1- I pass the tensor through a forward module that extracts spatio-temporal representations in the forward time direction and outputs [B, T, N, 1]. I squeeze the last dimension and permute the result to obtain [B, N, T].
2- I pass the tensor through a similar module, but one specialised in extracting backward spatio-temporal representations. Again, the output is [B, N, T].
3- I mix the information extracted by the two modules with torch.cat along the last dimension, obtaining a tensor of shape [B, N, 2*T].
4- Finally, I apply an MLP to that tensor to obtain a prediction of shape [B, N, T] that combines the representations extracted by the two modules in points 1 and 2.
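The four steps above can be sketched roughly as follows. This is only a shape check under stated assumptions, not the actual project code: `StandInEncoder` is a hypothetical placeholder for the forward and backward modules, both modules are assumed to receive the same input, and the sizes and the MLP hidden width are made up.

```python
import torch
from torch import nn

B, T, N, F = 2, 12, 5, 3  # hypothetical sizes: batch, time, node, features


class StandInEncoder(nn.Module):
    """Hypothetical placeholder for the forward/backward module: [B, T, N, F] -> [B, T, N, 1]."""

    def __init__(self, in_features):
        super().__init__()
        self.proj = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.proj(x)  # acts on the last (feature) dimension only


forward_module = StandInEncoder(F)
backward_module = StandInEncoder(F)
# Assumed MLP: maps the concatenated 2*T values of each node back to T values.
mlp = nn.Sequential(nn.Linear(2 * T, 64), nn.ReLU(), nn.Linear(64, T))

x = torch.randn(B, T, N, F)

# Steps 1 and 2: squeeze only the last dim, then swap time and node dims.
h_fwd = forward_module(x).squeeze(-1).permute(0, 2, 1)   # [B, N, T]
h_bwd = backward_module(x).squeeze(-1).permute(0, 2, 1)  # [B, N, T]

# Step 3: concatenate along the last dimension.
mixed = torch.cat([h_fwd, h_bwd], dim=-1)                # [B, N, 2*T]

# Step 4: the MLP acts on the last dimension, independently for every node.
pred = mlp(mixed)                                        # [B, N, T]
print(h_fwd.shape, mixed.shape, pred.shape)
```

Using `squeeze(-1)` rather than a bare `squeeze()` keeps the batch or node dimension intact when one of them happens to have size 1. Note also that, in this sketch, the permutation is what places time in the last position, which is the dimension that `nn.Linear` mixes.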
I built this pipeline for a project I started and left incomplete a few months ago. Now that I am reviewing it, I have the feeling that these permutations do not make sense. What am I missing?
A couple of questions.
- When you say a forwarding module, do you mean a Linear Layer, or a 1d/2d convolution layer?
On a side note, a Linear layer won’t extract spatio-temporal representations effectively, if at all. Work through a few matmul examples in a spreadsheet, and it should become clear why that operation is essentially order blind: every input position gets its own independent weight, so the layer has no built-in notion of which positions are neighbours.
Convolutions, on the other hand, are far more suitable for that job. RNNs can also work, but they tend to lose information over longer sequences.
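As a stand-in for the spreadsheet exercise, here is a minimal pure-Python sketch of what "order blind" means here. The numbers, the `linear` helper and the permutation are all made up for illustration: if the input positions are shuffled and the weight columns are shuffled the same way, the output is unchanged, so nothing in the operation depends on which time steps sit next to each other.

```python
def linear(x, weight, bias):
    """Plain matmul plus bias: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


x = [1, 2, 3, 4]                      # four time steps (hypothetical values)
weight = [[2, 0, -1, 3],              # two output units, one weight per time step
          [1, 1, 1, 1]]
bias = [0, 5]

perm = [2, 0, 3, 1]                   # an arbitrary reordering of the time steps
x_shuffled = [x[i] for i in perm]
weight_shuffled = [[row[i] for i in perm] for row in weight]

y = linear(x, weight, bias)
y_shuffled = linear(x_shuffled, weight_shuffled, bias)
print(y, y_shuffled)                  # [11, 15] [11, 15]
```

A convolution does not have this property, because the same small kernel is slid over neighbouring positions, so shuffling the input changes which values the kernel sees together.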
- Are you applying any attention layers?
- Are you using positional/temporal embeddings?
As for permuting the time and node dims, it’s hard to say what that is or isn’t accomplishing without knowing the layer types involved.
Thank you very much for your answer. I will answer you point by point to keep the order of the topics.
The module that receives the input tensor of shape [B, T, N, F] is a recurrent GNN. Essentially, it is composed of two submodules: the first is an RNN, and the second implements a GCN.
I do not use any attention layers, only RNN + GCN layers in the first two modules and, finally, linear layers to blend the previously extracted representations.
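To make the setup concrete, a module of this kind could look roughly like the sketch below. It is an illustration under assumptions, not the project's actual code: the class name `RecurrentGCN`, the choice of a GRU, the single graph-convolution step, the ring-shaped adjacency matrix and all sizes are hypothetical.

```python
import torch
from torch import nn


def normalize_adjacency(adj):
    """Symmetric normalisation D^{-1/2} (A + I) D^{-1/2} commonly used with GCNs."""
    a = adj + torch.eye(adj.shape[0])
    deg_inv_sqrt = a.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt[:, None] * a * deg_inv_sqrt[None, :]


class RecurrentGCN(nn.Module):
    """Hypothetical RNN + GCN block: [B, T, N, F] -> [B, T, N, 1]."""

    def __init__(self, in_features, hidden):
        super().__init__()
        self.rnn = nn.GRU(in_features, hidden, batch_first=True)
        self.out = nn.Linear(hidden, 1)

    def forward(self, x, adj_norm):
        b, t, n, f = x.shape
        # Fold nodes into the batch so the GRU runs over time for each node.
        seq = x.permute(0, 2, 1, 3).reshape(b * n, t, f)
        h, _ = self.rnn(seq)                             # [B*N, T, H]
        h = h.reshape(b, n, t, -1).permute(0, 2, 1, 3)   # [B, T, N, H]
        # One graph-convolution step: mix node states through the adjacency.
        h = torch.einsum("ij,btjh->btih", adj_norm, h)   # [B, T, N, H]
        return self.out(h)                               # [B, T, N, 1]


B, T, N, F = 2, 12, 5, 3                                 # hypothetical sizes
adj = torch.zeros(N, N)
for i in range(N):                                       # assumed ring graph
    adj[i, (i + 1) % N] = 1.0
    adj[(i + 1) % N, i] = 1.0

model = RecurrentGCN(in_features=F, hidden=8)
out = model(torch.randn(B, T, N, F), normalize_adjacency(adj))
print(out.shape)
```

In this sketch the permutations inside `forward` exist only to put time in the sequence position expected by the GRU and then restore the original [B, T, N, ...] layout before the graph step.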
I am not sure I understood what you mean by positional/temporal embeddings, but apart from the first two modules I do not apply any further transformations that operate on the temporal dimension of the tensors.
I hope this post can shed some light on my problem setup. Thank you very much again!