I noticed that torchvision uses 1x1 Conv2d layers to implement squeeze-and-excitation. However, an equivalent Linear implementation appears to be noticeably faster than the Conv2d one. Is there a reason torchvision chose Conv2d, or could this be improved? Below is the code I used to time the two.
```python
import torch
import torch.nn as nn
import time
from torch.backends import cudnn

cudnn.benchmark = True


@torch.no_grad()
def main():
    # Input shaped like the output of global average pooling in an SE block.
    data = torch.randn(64, 512, 1, 1).float().to("cuda")
    shape = data.shape

    # Linear implementation of the excitation path.
    m1 = nn.Sequential(
        nn.Linear(512, 64),
        nn.ReLU(True),
        nn.Linear(64, 512),
        nn.Sigmoid(),
    ).float().to("cuda")

    # Equivalent 1x1-convolution implementation (as in torchvision).
    m2 = nn.Sequential(
        nn.Conv2d(512, 64, 1),
        nn.ReLU(True),
        nn.Conv2d(64, 512, 1),
        nn.Sigmoid(),
    ).float().to("cuda")

    # Copy the conv weights into the linear layers: a 1x1 conv weight of
    # shape (out, in, 1, 1) squeezes to the (out, in) linear weight.
    m1[0].weight.data = m2[0].weight.data.squeeze()
    m1[0].bias.data = m2[0].bias.data
    m1[2].weight.data = m2[2].weight.data.squeeze()
    m1[2].bias.data = m2[2].bias.data

    # Sanity check: both modules should produce (nearly) identical outputs.
    res1 = data.squeeze()
    res1 = m1(res1)
    res1 = res1.reshape(shape)
    res2 = m2(data)
    print(torch.abs(res1 - res2).sum())

    # Time the Linear version.
    for _ in range(2):  # warmup
        data = data.squeeze()
        data = m1(data)
        data = data.reshape(shape)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(10000):
        data = data.squeeze()
        data = m1(data)
        data = data.reshape(shape)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)

    # Time the Conv2d version.
    for _ in range(2):  # warmup
        data = m2(data)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(10000):
        data = m2(data)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)


if __name__ == "__main__":
    main()
```
output:
```
tensor(0.0007, device='cuda:0')  # total absolute difference, due to floating-point rounding
2.3467748999828473
3.0033348000142723
```
As shown above, the Linear version takes about 2.35 s for 10,000 iterations versus about 3.00 s for the Conv2d version, i.e. roughly 20% faster.
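For an input of shape (N, C, 1, 1) the two formulations are mathematically identical, since a 1x1 convolution over a 1x1 spatial grid is just a matrix multiply; the speed gap presumably comes from the Linear path dispatching straight to a GEMM rather than going through the convolution machinery. For reference, here is a minimal sketch of what a Linear-based SE block could look like. The class name `LinearSqueezeExcitation` and its structure are my own illustration, not torchvision's API (torchvision's actual `SqueezeExcitation` also takes configurable activation and scale-activation functions):

```python
import torch
import torch.nn as nn


class LinearSqueezeExcitation(nn.Module):
    """Illustrative SE block using Linear layers instead of 1x1 convs."""

    def __init__(self, channels: int, squeeze_channels: int):
        super().__init__()
        self.fc1 = nn.Linear(channels, squeeze_channels)
        self.fc2 = nn.Linear(squeeze_channels, channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, _, _ = x.shape
        # Squeeze: global average pool to (N, C).
        scale = x.mean(dim=(2, 3))
        # Excitation: two Linear layers instead of two 1x1 convs.
        scale = self.activation(self.fc1(scale))
        scale = self.scale_activation(self.fc2(scale))
        # Rescale the input channel-wise.
        return x * scale.view(n, c, 1, 1)
```

As an aside, the same comparison could also be run with `torch.utils.benchmark`, which handles warmup and CUDA synchronization automatically; a sketch, reusing `m1`, `m2`, and `data` from the script above:

```python
from torch.utils import benchmark

t_linear = benchmark.Timer(
    stmt="m1(x.squeeze()).reshape(x.shape)",
    globals={"m1": m1, "x": data},
)
t_conv = benchmark.Timer(
    stmt="m2(x)",
    globals={"m2": m2, "x": data},
)
print(t_linear.timeit(10000))
print(t_conv.timeit(10000))
```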