How to pass a 3D tensor to a Linear layer?

I have a 3D tensor (5x9x12) that I want to map to a (5x9x1) tensor through a linear layer. But I found that nn.Linear requires the input to be a matrix instead of a 3D tensor. How can I achieve this?


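You can do that with Tensor.view(): flatten the leading dimensions, apply the layer, then restore the original shape.

import torch
x = torch.randn(5, 9, 12)          # example input
linear = torch.nn.Linear(12, 1)    # maps the last dimension 12 -> 1
b_size = x.size(0)                 # save the batch size (5) BEFORE flattening
x = x.view(-1, 12)                 # (5, 9, 12) -> (45, 12)
x = linear(x)                      # (45, 12) -> (45, 1)
x = x.view(b_size, -1, 1)          # (45, 1) -> (5, 9, 1)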
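If you do want the reshape route (for example, with a module that only accepts 2D input), here is a minimal sketch of a generic helper that works for any number of leading dimensions. `apply_to_last_dim` is a hypothetical name, not a PyTorch function. It uses `reshape` instead of `view`, so it also works on non-contiguous tensors.

```python
import torch
import torch.nn as nn

def apply_to_last_dim(layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten all leading dims, apply `layer`, restore them."""
    leading_shape = x.shape[:-1]                        # e.g. (5, 9) for a (5, 9, 12) input
    flat = x.reshape(-1, x.shape[-1])                   # (5, 9, 12) -> (45, 12)
    out = layer(flat)                                   # (45, 12) -> (45, out_features)
    return out.reshape(*leading_shape, out.shape[-1])   # (45, 1) -> (5, 9, 1)

linear = nn.Linear(12, 1)
x = torch.randn(5, 9, 12)   # random data, for illustration only
y = apply_to_last_dim(linear, x)
print(y.shape)  # torch.Size([5, 9, 1])
```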

Thanks, that should solve my problem. A single function that does the trick would be clearer, though.

According to the nn.Linear documentation, the input can have any number of leading dimensions (shape (*, in_features)), so you simply need to pass the 3D input through an nn.Linear(12, 1) layer.
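A minimal sketch of the direct approach, assuming a recent PyTorch version where nn.Linear accepts inputs of shape (*, in_features). The input is random data used only to show the shapes.

```python
import torch
import torch.nn as nn

linear = nn.Linear(12, 1)      # in_features=12 must match the last input dimension
x = torch.randn(5, 9, 12)      # example (5, 9, 12) input
y = linear(x)                  # leading dims (5, 9) are treated as batch dimensions
print(y.shape)  # torch.Size([5, 9, 1])
```

No reshaping is needed; if you later want a (5, 9) tensor instead, `y.squeeze(-1)` drops the trailing size-1 dimension.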


Could you explain how to do this in a non-iterative manner? I think it would be very helpful.

@Dimitrisl: as @unnatjain said, if you pass a (5, 9, 12) tensor through an nn.Linear(12, 1) layer, it automatically applies the linear transformation to the last dimension only, giving you a (5, 9, 1) tensor. No iteration is required.
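To make "applies the transformation to the last dimension only" concrete, here is a small sketch (random weights and data, purely illustrative) checking that nn.Linear on a (5, 9, 12) input matches `x @ weight.T + bias` and also matches an explicit per-position Python loop. The loop is there only for comparison; the layer does the same work in one batched call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(12, 1)
x = torch.randn(5, 9, 12)

# What nn.Linear computes: y = x @ W^T + b, broadcast over all leading dims.
y_layer = linear(x)
y_manual = x @ linear.weight.T + linear.bias   # (5, 9, 12) @ (12, 1) + (1,) -> (5, 9, 1)

# The same result built with explicit Python loops, for comparison only.
y_loop = torch.empty(5, 9, 1)
with torch.no_grad():
    for i in range(5):
        for j in range(9):
            y_loop[i, j] = linear(x[i, j])     # x[i, j] has shape (12,) -> output (1,)

print(torch.allclose(y_layer, y_manual, atol=1e-6),
      torch.allclose(y_layer, y_loop, atol=1e-6))
```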
