Speeding up 7D sliding window summation

Hi!
I’ve been having trouble switching an implementation from NumPy to PyTorch on a 7-dimensional tensor. The NumPy/Numba code looks as follows:

import numba as nb
import numpy as np

@nb.njit
def numba_7d_conv_cpu(gamma, new_gamma, ratio_sq: float, kernel_size: int):
  # Sum each non-overlapping kernel_size^4 block of gamma, scaled by ratio_sq.
  for b in range(new_gamma.shape[0]):
    for c in range(new_gamma.shape[1]):
      for i in range(new_gamma.shape[2]):
        for j in range(new_gamma.shape[3]):
          for cp in range(new_gamma.shape[4]):
            for k in range(new_gamma.shape[5]):
              for l in range(new_gamma.shape[6]):
                new_gamma[b, c, i, j, cp, k, l] = ratio_sq * np.sum(
                    gamma[b, c,
                          kernel_size * i:kernel_size * (i + 1),
                          kernel_size * j:kernel_size * (j + 1),
                          cp,
                          kernel_size * k:kernel_size * (k + 1),
                          kernel_size * l:kernel_size * (l + 1)])
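
For reference, a minimal call looks like this (the shapes, the value of ratio_sq, and the variable names are placeholders of mine; H and W must be divisible by the kernel size):

B, C, H, W, k = 2, 3, 8, 8, 2
gamma = np.random.rand(B, C, H, W, C, H, W).astype(np.float32)
new_gamma = np.empty((B, C, H // k, W // k, C, H // k, W // k), dtype=np.float32)
numba_7d_conv_cpu(gamma, new_gamma, ratio_sq=1.0, kernel_size=k)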

Notice I make use of Numba’s @njit in order to achieve competitive performance. The first dimension is the batch size, and the remaining dimensions repeat (channels, image height, image width) twice, since the tensor is a covariance matrix over the original input.
I’m interested in porting it to PyTorch to avoid back-and-forth transfers with the GPU and to scale the batch size while keeping runtime low.

Using the same nested for loops in pure PyTorch code with torch.jit does not yield improved runtimes (if I understand correctly, loops aren’t handled well by the JIT at the moment?). Should I consider writing CUDA kernels for this task?
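
For what it’s worth, since the windows don’t overlap (the stride equals kernel_size), the whole operation is a block sum, which can be written loop-free in PyTorch by splitting each spatial dimension into (num_blocks, kernel_size) and summing over the intra-block axes. A minimal sketch, with the function name and shape conventions being my own assumptions:

import torch

def blocksum_7d(gamma: torch.Tensor, ratio_sq: float, k: int) -> torch.Tensor:
  # gamma: (B, C, H, W, C, H, W), with all spatial dims divisible by k
  B, C, H, W, C2, H2, W2 = gamma.shape
  # split every spatial dim into (num_blocks, k)
  g = gamma.reshape(B, C, H // k, k, W // k, k, C2, H2 // k, k, W2 // k, k)
  # sum the four intra-block axes -> (B, C, H//k, W//k, C, H2//k, W2//k)
  return ratio_sq * g.sum(dim=(3, 5, 8, 10))

This runs on the GPU without a custom kernel; whether it beats the Numba loop on CPU will depend on your sizes.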

I managed to do it by reshaping → convolution → reshaping → convolution → reshaping. I’d be happy to help if anyone encounters the same issue.
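
In case it helps others, here is a sketch of that pipeline. I’m using F.avg_pool2d with divisor_override=1 as a stand-in for the convolution step (it computes the same non-overlapping window sums), so treat the details as an assumption rather than the exact original code:

import torch
import torch.nn.functional as F

def blocksum_7d_pool(gamma: torch.Tensor, ratio_sq: float, k: int) -> torch.Tensor:
  B, C, H, W, C2, H2, W2 = gamma.shape
  # first pass: fold everything but the last two dims into the batch
  x = gamma.reshape(B * C * H * W * C2, 1, H2, W2)
  x = F.avg_pool2d(x, k, divisor_override=1)  # sum pooling over (H2, W2)
  x = x.reshape(B, C, H, W, C2, H2 // k, W2 // k)
  # second pass: bring (H, W) to the trailing positions and pool them
  x = x.permute(0, 1, 4, 5, 6, 2, 3)
  x = x.reshape(B * C * C2 * (H2 // k) * (W2 // k), 1, H, W)
  x = F.avg_pool2d(x, k, divisor_override=1)  # sum pooling over (H, W)
  x = x.reshape(B, C, C2, H2 // k, W2 // k, H // k, W // k)
  # restore the original dim order: (B, C, H//k, W//k, C, H2//k, W2//k)
  return ratio_sq * x.permute(0, 1, 5, 6, 2, 3, 4)

A quick torch.allclose check against the Numba loop’s output is a good sanity test before relying on either version.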