nn.Linear with multi-dimensional channel data

I was trying to implement an MLP layer that takes 3-dimensional data but processes it along only one axis (so the other two dimensions are treated as channels in this case). I’ve had a quick look through the documentation and couldn’t find anything similar to what I want. Granted, I could just call nn.Linear separately for each slice, but that would greatly slow down the layer since the calls won’t be parallelized. If anyone can point me toward something that might meet my needs, that would be awesome!

Could you post a small pseudocode snippet, please?
nn.Linear accepts additional dimensions as [batch_size, *, in_features], so you might be able to directly use it.
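
If the axis to be processed is not the last one, one possible approach is to move it to the end, apply the layer, and move it back. The sketch below assumes that setup; the helper name linear_along_axis and the shapes are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn


def linear_along_axis(layer, x, axis):
    """Apply `layer` along `axis` of `x` (hypothetical helper, not a PyTorch API)."""
    # nn.Linear always acts on the last dimension, so swap the target axis there.
    x = x.transpose(axis, -1)
    x = layer(x)
    # Swap back so the transformed axis returns to its original position.
    return x.transpose(axis, -1)


x = torch.randn(2, 3, 4, 5)   # assumed example shape
layer = nn.Linear(3, 7)       # in_features must match the size of the chosen axis
out = linear_along_axis(layer, x, axis=1)
print(out.shape)              # torch.Size([2, 7, 4, 5])
```

All slices along the other dimensions go through a single batched call, so nothing is looped over in Python.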


Thanks, that should work out.

Actually, I ran into a bit of an issue again. It seems like nn.Linear doesn’t like inputs with more than 2 dimensions:

Here is the pseudocode:

input_matrix: shape [X, Y, Z, Channels]

linear_layer = nn.Linear(in_channels, out_channels)  # in_channels == Channels

output_matrix = linear_layer(input_matrix)

It seems to work for me in 1.7.0.dev20200830:

import torch
import torch.nn as nn

x, y, z, channels = 2, 3, 4, 5
input = torch.randn((x, y, z, channels))
lin = nn.Linear(channels, 6)  # acts on the last dimension only
out = lin(input)
print(out.shape)
> torch.Size([2, 3, 4, 6])

Aha, I screwed up the batch-normalization part instead. Thanks for the reply, though! (Gotta be a little more careful going through the errors, haha.)
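
The thread does not say what the batch-normalization mistake was. One common source of shape errors in this kind of setup is that nn.BatchNorm1d expects the feature dimension at index 1, while the output of nn.Linear keeps the features last. A minimal sketch of one way around that, assuming the same shapes as the example above (this is a guess at the situation, not the poster's actual code):

```python
import torch
import torch.nn as nn

x, y, z, channels = 2, 3, 4, 5
inp = torch.randn(x, y, z, channels)

lin = nn.Linear(channels, 6)
bn = nn.BatchNorm1d(6)  # normalizes over the 6 output features

hidden = lin(inp)                        # shape [2, 3, 4, 6], features last
flat = hidden.reshape(-1, 6)             # collapse leading dims: [24, 6]
normed = bn(flat).reshape(hidden.shape)  # restore [2, 3, 4, 6]
print(normed.shape)                      # torch.Size([2, 3, 4, 6])
```

Flattening the leading dimensions treats every (x, y, z) position as a separate sample for the batch statistics, which may or may not be the intended behavior.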