Implementing a Fully Connected Layer using nn.Conv2d vs nn.Linear

It is possible to implement a fully connected layer either with nn.Linear or with nn.Conv2d, setting kernel_size equal to the input's spatial size. The following piece of code demonstrates that we get identical results with both approaches. Is there any advantage to using one approach over the other? By advantage, I mean number of parameters, memory use, speed, etc.?

import torch
import torch.nn as nn

# Define input
inp = 1000*torch.rand(1,10,3,1)

# Define fc layer using nn.Linear and initialize weights to 1 and biases to 0
fc_lin = nn.Linear(30,4) # 10*3*1 = 30 input features
nn.init.ones_(fc_lin.weight)
nn.init.zeros_(fc_lin.bias)
fc_lin.eval();

# Define fc layer using nn.Conv2d with kernel_size equal to the input's spatial size, and initialize weights to 1 and biases to 0
fc_conv = nn.Conv2d(10,4,kernel_size=(3,1))
nn.init.ones_(fc_conv.weight)
nn.init.zeros_(fc_conv.bias)
fc_conv.eval();

# Run input through fc linear layer and add dummy spatial dimensions
out_lin = fc_lin(torch.flatten(inp,1))[:,:,None,None]

# Run input through fc conv layer
out_conv = fc_conv(inp)

# Check that outputs are equal
print(torch.equal(out_lin, out_conv))

The answer will depend on your particular hardware setup and the software libraries available; some convolution implementations dispatch to a matrix multiplication, in which case performance may be similar. However, my guess is that most convolution libraries (e.g., cuDNN) will not optimize for this use case, as it is somewhat unusual, whereas it is a typical workload for GEMM libraries like cuBLAS.

TL;DR: it will ultimately depend on your setup and the tensor shapes, so try them both if you are curious.
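If you do want to try both, a minimal timing sketch might look like the following. The batch size, iteration counts, and the `bench` helper are arbitrary choices for illustration, not a rigorous benchmark; note that both layers end up with the same parameter count (30*4 weights + 4 biases = 124).

```python
import time
import torch
import torch.nn as nn

# Hypothetical shapes for illustration: batch of 256 inputs of shape 10x3x1.
B, C, H, W = 256, 10, 3, 1
inp = torch.rand(B, C, H, W)

fc_lin = nn.Linear(C * H * W, 4)
fc_conv = nn.Conv2d(C, 4, kernel_size=(H, W))

def bench(fn, n=200):
    # Warm up, then average n forward passes.
    for _ in range(10):
        fn()
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n

with torch.no_grad():
    t_lin = bench(lambda: fc_lin(inp.flatten(1)))
    t_conv = bench(lambda: fc_conv(inp))

n_lin = sum(p.numel() for p in fc_lin.parameters())
n_conv = sum(p.numel() for p in fc_conv.parameters())
print(f"params: {n_lin} vs {n_conv}")
print(f"linear: {t_lin*1e6:.1f} us/iter, conv: {t_conv*1e6:.1f} us/iter")
```

On GPU you would also need to call `torch.cuda.synchronize()` around the timers, or use `torch.utils.benchmark` instead of hand-rolled timing.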

Another difference is that, if you want to apply BatchNorm1d to the FC output, using Conv1d is much more convenient than using Linear. This is because Conv1d produces output of shape [B, C, N], which is exactly the input shape BatchNorm1d expects, while Linear produces output of shape [B, N, C]. This is why many point cloud deep learning papers (e.g., PointNet) use Conv1d to implement FC layers.
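The shape mismatch above can be sketched as follows. The point counts and channel sizes are made up for illustration; the point is that the Linear path needs two transposes around BatchNorm1d while the Conv1d path needs none:

```python
import torch
import torch.nn as nn

B, N, C_in, C_out = 8, 100, 3, 64  # hypothetical: 100 points with 3 coordinates
bn = nn.BatchNorm1d(C_out)

# Conv1d path: channels-first [B, C, N] throughout, so BN1d applies directly.
pts_cf = torch.rand(B, C_in, N)
x = nn.Conv1d(C_in, C_out, kernel_size=1)(pts_cf)  # [B, 64, 100]
out_conv = bn(x)

# Linear path: channels-last [B, N, C], so we must transpose before and after BN1d.
pts_cl = pts_cf.transpose(1, 2)                    # [B, 100, 3]
y = nn.Linear(C_in, C_out)(pts_cl)                 # [B, 100, 64]
out_lin = bn(y.transpose(1, 2)).transpose(1, 2)    # [B, 100, 64]

print(out_conv.shape, out_lin.shape)
```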