nn.Linear with multi-dimensional channel data

I ran into another issue: it seems nn.Linear doesn't accept inputs with more than 2 dimensions:

Here is the pseudocode:

input_matrix = torch.randn(X, Y, Z, in_channel)  # shape: X x Y x Z x Channels

linear_layer = nn.Linear(in_channel, out_channel)

output_matrix = linear_layer(input_matrix)
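
For reference, here is a minimal runnable sketch of the pseudocode above. The sizes (X, Y, Z, in_channel, out_channel) are made-up values for illustration, not taken from my actual model. As far as I know, recent PyTorch versions apply nn.Linear over the last dimension and treat all leading dimensions as batch dimensions, so the direct call may simply work. If it does not in a given version, flattening to 2D and reshaping back should give the same result:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; these values are assumptions, not from the original post.
X, Y, Z = 2, 3, 4
in_channel, out_channel = 5, 7

input_matrix = torch.randn(X, Y, Z, in_channel)  # channels in the last dimension
linear_layer = nn.Linear(in_channel, out_channel)

# Option 1: call the layer directly. nn.Linear acts on the last dimension
# and leaves the leading dimensions (X, Y, Z) untouched.
output_direct = linear_layer(input_matrix)

# Option 2: explicit fallback. Flatten to 2D, apply the layer, restore the shape.
flat = input_matrix.reshape(-1, in_channel)  # shape: (X*Y*Z, in_channel)
output_reshaped = linear_layer(flat).reshape(X, Y, Z, out_channel)

print(output_direct.shape)  # torch.Size([2, 3, 4, 7])
```

Both options apply the same weights to every position's channel vector, so the only requirement is that the channel dimension comes last. If the channels sit in a different dimension, they would need to be moved to the end first (for example with permute).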