Hi
I want to make a convolution with 1x1x4x4 kernel, and for reason I cannot understand, it takes more time that make a convolution with 16x1x4x4 kernel
Can you please help me understand why?
import time
input = torch.randn(64,1,28,28)
kernels = torch.tensor([[[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., -1., -1.],
[ 1., 1., -1., -1.],
[ 1., 1., -1., -1.],
[ 1., 1., -1., -1.]],
[[ 1., -1., -1., 1.],
[ 1., -1., -1., 1.],
[ 1., -1., -1., 1.],
[ 1., -1., -1., 1.]],
[[ 1., -1., 1., -1.],
[ 1., -1., 1., -1.],
[ 1., -1., 1., -1.],
[ 1., -1., 1., -1.]],
[[-1., -1., -1., -1.],
[-1., -1., -1., -1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[-1., -1., 1., 1.],
[-1., -1., 1., 1.],
[ 1., 1., -1., -1.],
[ 1., 1., -1., -1.]],
[[-1., 1., 1., -1.],
[-1., 1., 1., -1.],
[ 1., -1., -1., 1.],
[ 1., -1., -1., 1.]],
[[-1., 1., -1., 1.],
[-1., 1., -1., 1.],
[ 1., -1., 1., -1.],
[ 1., -1., 1., -1.]],
[[ 1., 1., 1., 1.],
[-1., -1., -1., -1.],
[-1., -1., -1., -1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., -1., -1.],
[-1., -1., 1., 1.],
[-1., -1., 1., 1.],
[ 1., 1., -1., -1.]],
[[ 1., -1., -1., 1.],
[-1., 1., 1., -1.],
[-1., 1., 1., -1.],
[ 1., -1., -1., 1.]],
[[ 1., -1., 1., -1.],
[-1., 1., -1., 1.],
[-1., 1., -1., 1.],
[ 1., -1., 1., -1.]],
[[-1., -1., -1., -1.],
[ 1., 1., 1., 1.],
[-1., -1., -1., -1.],
[ 1., 1., 1., 1.]],
[[-1., -1., 1., 1.],
[ 1., 1., -1., -1.],
[-1., -1., 1., 1.],
[ 1., 1., -1., -1.]],
[[-1., 1., 1., -1.],
[ 1., -1., -1., 1.],
[-1., 1., 1., -1.],
[ 1., -1., -1., 1.]],
[[-1., 1., -1., 1.],
[ 1., -1., 1., -1.],
[-1., 1., -1., 1.],
[ 1., -1., 1., -1.]]]]).reshape(16,1,4,4)
first_kernel = kernels[0,0].reshape(1,1,4,4).clone().detach()
start_time = time.time()
for i in range(1000):
after_conv = F.conv2d(input, kernels, stride=[1], padding=[1])
end_time = time.time()
print("--- %s seconds ---" % (end_time - start_time))
start_time = time.time()
for i in range(1000):
after_conv = F.conv2d(input, first_kernel, stride=[1], padding=[1])
end_time = time.time()
print("--- %s seconds ---" % (end_time - start_time))
Output:
— 2.058915853500366 seconds —
— 2.1254472732543945 seconds —
Thank you