Actually, I ran into a bit of an issue again: it seems like nn.Linear doesn't accept input with more than 2 dimensions.
Here is the pseudo code:
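import torch
import torch.nn as nn
X, Y, Z, in_channels, out_channels = 4, 5, 6, 3, 8  # example sizes
input_matrix = torch.randn(X, Y, Z, in_channels)  # shape [X x Y x Z x Channels]
linear_layer = nn.Linear(in_channels, out_channels)  # acts on the last (channel) dimension
output_matrix = linear_layer(input_matrix)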
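A minimal sketch, assuming the layer is meant to act on the last (channel) dimension. Current PyTorch documentation lists nn.Linear's input shape as (*, in_features), so recent versions should accept the 4D tensor directly. On a version that only accepts 2D input, a common workaround is to flatten the leading dimensions into one batch dimension, apply the layer, and reshape back. The helper apply_linear_last_dim and the example sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical example sizes, chosen only for illustration
X, Y, Z, in_channels, out_channels = 4, 5, 6, 3, 8

input_matrix = torch.randn(X, Y, Z, in_channels)
linear_layer = nn.Linear(in_channels, out_channels)

def apply_linear_last_dim(layer, x):
    """Hypothetical helper: apply `layer` to the last dimension of `x`
    by flattening all leading dimensions into a single 2D batch."""
    leading_shape = x.shape[:-1]                          # (X, Y, Z)
    flat = x.reshape(-1, x.shape[-1])                     # (X*Y*Z, in_channels), a 2D input
    out = layer(flat)                                     # (X*Y*Z, out_channels)
    return out.reshape(*leading_shape, out.shape[-1])     # (X, Y, Z, out_channels)

output_matrix = apply_linear_last_dim(linear_layer, input_matrix)
print(output_matrix.shape)  # torch.Size([4, 5, 6, 8])
```

Because the same weights are applied independently to every (x, y, z) position, flattening and reshaping back gives the same result as applying the layer position by position.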