Implementing Linear as Conv2d

Let’s say we have a 4D tensor of shape (B, C, H, W). We can implement a Linear layer over the channel dimension in two ways:

import torch.nn as nn

lin = nn.Linear(C, K)                 # acts on the last dimension
con = nn.Conv2d(C, K, kernel_size=1)  # acts on the channel dimension

In practice they are equivalent: both apply the same affine transform independently at every spatial position, and they only differ in the expected input and output layout. For lin it is (B, H, W, C) -> (B, H, W, K), whereas for con it is (B, C, H, W) -> (B, K, H, W).
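To make the equivalence concrete, here is a minimal sketch (with arbitrary example sizes) that copies the Linear weights into the 1x1 conv and checks that the outputs match after permuting the layout:

import torch
import torch.nn as nn

B, C, K, H, W = 2, 8, 16, 4, 4  # example sizes, chosen arbitrarily

lin = nn.Linear(C, K)
con = nn.Conv2d(C, K, kernel_size=1)

# Copy the Linear weights into the 1x1 conv: (K, C) -> (K, C, 1, 1)
with torch.no_grad():
    con.weight.copy_(lin.weight.view(K, C, 1, 1))
    con.bias.copy_(lin.bias)

x = torch.randn(B, C, H, W)

out_con = con(x)                          # (B, K, H, W)
out_lin = lin(x.permute(0, 2, 3, 1))      # (B, H, W, K)
out_lin = out_lin.permute(0, 3, 1, 2)     # back to (B, K, H, W)

print(torch.allclose(out_con, out_lin, atol=1e-6))  # True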

After running some tests (on CPU) I found that con is faster than lin, so I am wondering what the reason for this difference in speed is, and why someone would prefer using nn.Linear at all?
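For context, the kind of timing test I mean looks roughly like this (the sizes and iteration count are just an illustration, not my exact benchmark):

import time
import torch
import torch.nn as nn

B, C, K, H, W = 32, 64, 128, 56, 56  # arbitrary example sizes

lin = nn.Linear(C, K)
con = nn.Conv2d(C, K, kernel_size=1)

x_con = torch.randn(B, C, H, W)
x_lin = x_con.permute(0, 2, 3, 1).contiguous()

def bench(fn, x, n=50):
    with torch.no_grad():
        fn(x)  # warm-up
        t0 = time.perf_counter()
        for _ in range(n):
            fn(x)
    return (time.perf_counter() - t0) / n

print("conv  :", bench(con, x_con))
print("linear:", bench(lin, x_lin))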

P.S. Similarly, for a 3D tensor of shape (B, C, L) we can use Conv1d to implement a Linear layer.
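For completeness, the 1D analogue under the same layout convention:

con1d = nn.Conv1d(C, K, kernel_size=1)  # (B, C, L) -> (B, K, L)
lin   = nn.Linear(C, K)                 # (B, L, C) -> (B, L, K)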
