Unable to increase performance for custom CNN using torchscript

Hello,
I tried to compare the runtime of a convolution layer I implemented from scratch vs. a TorchScript version of the same layer vs. the torch.nn.Conv2d() module, over 100 iterations with input (128, 3, 28, 28), out_channels=64, kernel_size=3.

Convolution layer from scratch on CUDA -> 9.366 seconds
TorchScript version of the same layer on CUDA -> 6.636 seconds
torch.nn.Conv2d() -> 475.614 milliseconds

Is there any problem with my approach, and how can I optimize it further? I would appreciate any help with this.

My code

import torch
import torch.nn as nn

class conv2D(nn.Module):
  def __init__(self, in_channel, out_channel, kernel_size):
    super(conv2D, self).__init__()
    self.weight = torch.nn.Parameter(torch.ones(out_channel, in_channel, kernel_size, kernel_size))
    self.bias = torch.nn.Parameter(torch.zeros(out_channel))
    self.kernel_size = kernel_size
    self.in_channel = in_channel
    self.out_channel = out_channel

  def forward(self, image):
    batch_size = image.shape[0]
    img_height = image.shape[2]
    img_width = image.shape[3]
    out_height = img_height - self.kernel_size + 1
    out_width = img_width - self.kernel_size + 1

    # allocate the output on the same device/dtype as the input
    output = torch.zeros(batch_size, self.out_channel, out_height, out_width,
                         device=image.device, dtype=image.dtype)
    for k in range(batch_size):
      for i in range(out_height):
        for j in range(out_width):
          # dot product of the current patch with every filter at once
          temp = torch.sum(image[k, :, i:i+self.kernel_size, j:j+self.kernel_size] * self.weight, dim=(1, 2, 3))
          output[k, :, i, j] = temp + self.bias
    return output
x = torch.ones(128,3,28,28).to("cuda")
c = conv2D(3,64,3).to("cuda")
c_s = torch.jit.script(c).to("cuda")
c_s(x)

Scripting the model and running it once with a sample input so the JIT can generate an optimized graph.
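For reference, the compiled result can also be inspected like this (the nested Python loops are still visible in both the scripted code and the IR graph):

# optional: inspect what the JIT actually compiled
print(c_s.code)   # TorchScript source of forward()
print(c_s.graph)  # underlying IR graph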

with torch.autograd.profiler.profile(use_cuda=True) as prof:
  with torch.no_grad():
    for i in range(100):
      c(x)
print(prof.table())

Profiling both the normal (eager) module above and the scripted one below.

with torch.autograd.profiler.profile(use_cuda=True) as prof:
  with torch.no_grad():
    for i in range(100):
      c_s(x)
print(prof.table())
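To make sure the comparison is fair, a quick sanity check (not included in the timings above) could copy the same parameters into nn.Conv2d and confirm both layers produce the same output; ref and x_small below are just illustrative names:

# sanity check: same parameters in nn.Conv2d should give the same output
ref = nn.Conv2d(3, 64, 3).to("cuda")
with torch.no_grad():
  ref.weight.copy_(c.weight)
  ref.bias.copy_(c.bias)
  x_small = torch.ones(2, 3, 28, 28, device="cuda")  # small batch, the loop version is slow
  print(torch.allclose(c(x_small), ref(x_small), atol=1e-4))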

Convolutions are not easy to optimize and I'm not sure if there is any JIT compiler in the wild that is currently able to optimize these layers to competitive performance (please let me know if you find one :wink: ).
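That said, most of the gap in your numbers likely comes from the Python-level loops rather than the convolution kernel itself; TorchScript only removes some per-op overhead, while the loops remain in the graph. As a rough sketch, assuming stride 1 and no padding to match your layer (Conv2dUnfold is just an illustrative name), an im2col-style rewrite with F.unfold replaces the loops with a single batched matmul and should get much closer to nn.Conv2d:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2dUnfold(nn.Module):
  # im2col-style convolution: unfold the input into patches, then one matmul
  def __init__(self, in_channel, out_channel, kernel_size):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(out_channel, in_channel, kernel_size, kernel_size))
    self.bias = nn.Parameter(torch.zeros(out_channel))
    self.kernel_size = kernel_size
    self.out_channel = out_channel

  def forward(self, image):
    batch_size, _, img_height, img_width = image.shape
    out_height = img_height - self.kernel_size + 1
    out_width = img_width - self.kernel_size + 1
    # (N, C*k*k, L) with L = out_height * out_width: one column per output location
    cols = F.unfold(image, self.kernel_size)
    # (out_channel, C*k*k) @ (N, C*k*k, L) -> (N, out_channel, L), then add the bias
    out = self.weight.view(self.out_channel, -1) @ cols + self.bias.view(1, -1, 1)
    return out.view(batch_size, self.out_channel, out_height, out_width)

This is still not as fast as the cuDNN kernels behind nn.Conv2d, but it avoids launching tens of thousands of tiny kernels per forward pass.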