Helix Convolution in PyTorch

I am currently investigating how to build an efficient convolutional neural network that operates on arrays of up to 5 or 6 dimensions.

I am aware that many of the tools used for convolutional neural networks do not really handle N-D convolutions, so I decided to write an implementation of helix convolution, in which an N-D convolution is treated as one large 1D convolution (see Reference 1, http://sepwww.stanford.edu/public/docs/sep95/jon1/paper_html/node2.html, and Reference 2, https://sites.ualberta.ca/~mostafan/Files/Papers/md_convolution_TLE2009.pdf, for more details of the concept).
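A minimal 2D sketch of the concept, assuming zero padding on the high end of each axis (helix_conv2d_full is a hypothetical helper name, and np.convolve stands in for any 1D convolution routine). Once both arrays are padded to the full-convolution shape, every row has the same width, so adding flat indices on the unwound strip never carries from one row into the next, and a single 1D convolution reproduces the 2D one:

```python
import numpy as np


def helix_conv2d_full(a, k):
    """Full 2-D convolution of a with k, computed as one 1-D convolution."""
    full = (a.shape[0] + k.shape[0] - 1, a.shape[1] + k.shape[1] - 1)

    # Zero-pad each array at the end of both axes up to the full output shape.
    a_pad = np.zeros(full)
    a_pad[: a.shape[0], : a.shape[1]] = a
    k_pad = np.zeros(full)
    k_pad[: k.shape[0], : k.shape[1]] = k

    # Unwind the rows into one long strip (the "helix") and convolve in 1-D.
    strip = np.convolve(a_pad.ravel(), k_pad.ravel())

    # The first prod(full) samples, wound back up row by row, are the 2-D result;
    # everything after them is zero.
    return strip[: full[0] * full[1]].reshape(full)


a = np.arange(12.0).reshape(3, 4)
k = np.array([[1.0, 2.0], [0.5, -1.0]])
print(helix_conv2d_full(a, k))
```

The same argument holds axis by axis in any number of dimensions, which is why the method generalises to N-D without extra code.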

I did this under the (possibly incorrect) assumptions that a single large 1D convolution is likely to be easier to run efficiently on a GPU than a multidimensional one, and that the method scales trivially to N dimensions.

In particular, Reference 2 states:

We have not found important gains in computational efficiency between N-D standard convolution versus using the algorithm described in the text. We have, however, found that writing codes for seismic data regularization with the described trick leads to algorithms that can easily handle regularization problems with any number of spatial dimensions (Naghizadeh and Sacchi, 2009).

I have written the implementation below and compared it against scipy.signal.fftconvolve. On the CPU it is slower than fftconvolve, but I would nonetheless like to see how it performs on the GPU in PyTorch as a forward convolutional layer.

Can someone kindly help me port this code to PyTorch so I can verify how it behaves?
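One possible starting point for a port is sketched below. It is a sketch, not tested on a GPU here, assuming PyTorch 1.8 or later (for the torch.fft module) and a single learnable kernel; the names helix_conv_nd and HelixConvND are hypothetical, not an existing PyTorch API. It follows the same steps as the HelixCPU class (pad to the full shape, unwind, 1D FFT convolution, rewind, centred crop), but uses tensor operations throughout, so it runs on whichever device the inputs live on and supports autograd.

```python
import math

import torch
import torch.nn.functional as F


def helix_conv_nd(x: torch.Tensor, k: torch.Tensor, ndim: int) -> torch.Tensor:
    """'same'-size N-D convolution over the last `ndim` axes, done as one
    1-D FFT convolution of the unwound (helix) strips. Leading axes of x
    (e.g. batch, channel) are carried along and treated independently."""
    m = list(x.shape[-ndim:])
    n = list(k.shape[-ndim:])
    full = [a + b - 1 for a, b in zip(m, n)]  # shape of the full convolution

    def pad_to_full(t: torch.Tensor) -> torch.Tensor:
        # F.pad takes (left, right) pairs starting from the LAST axis.
        pads = []
        for size, target in zip(reversed(t.shape[-ndim:]), reversed(full)):
            pads += [0, target - size]
        return F.pad(t, pads)

    # C-order strides of the full shape; a strip's length is the flat index
    # of its array's last element (every index = size - 1) plus one.
    strides = [math.prod(full[i + 1:]) for i in range(ndim)]
    len_x = sum((s - 1) * st for s, st in zip(m, strides)) + 1
    len_k = sum((s - 1) * st for s, st in zip(n, strides)) + 1

    x_flat = pad_to_full(x).flatten(start_dim=-ndim)[..., :len_x]
    k_flat = pad_to_full(k).flatten(start_dim=-ndim)[..., :len_k]

    # Linear (not circular) 1-D convolution via real FFTs; the output length
    # len_x + len_k - 1 equals prod(full).
    out_len = len_x + len_k - 1
    nfft = 1 << (out_len - 1).bit_length()  # next power of two >= out_len
    spec = torch.fft.rfft(x_flat, n=nfft) * torch.fft.rfft(k_flat, n=nfft)
    y = torch.fft.irfft(spec, n=nfft)[..., :out_len]

    # Wind the strip back into N-D and crop the centre, like scipy's "same".
    y = y.reshape(*y.shape[:-1], *full)
    for d, (f, s) in enumerate(zip(full, m)):
        y = y.narrow(y.dim() - ndim + d, (f - s) // 2, s)
    return y


class HelixConvND(torch.nn.Module):
    """Hypothetical single-kernel layer: a learnable N-D kernel applied with
    helix_conv_nd. Gradients flow through the torch.fft calls."""

    def __init__(self, kernel_shape):
        super().__init__()
        self.weight = torch.nn.Parameter(0.01 * torch.randn(*kernel_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return helix_conv_nd(x, self.weight, self.weight.dim())
```

For example, HelixConvND((3, 3, 3, 3, 3)) applied to a tensor of shape (batch, 16, 16, 16, 16, 16) returns a tensor of the same shape, and calling .to("cuda") on both the layer and the input moves the FFTs to the GPU. Unlike PyTorch's built-in Conv layers, this sketch computes a true convolution (not cross-correlation) and does not mix input channels; for a learnable kernel the flip makes no practical difference, but multi-channel support would need an extra sum over input channels.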

"""
HELIX CONVOLUTION FUNCTION

Shrink:
CROPS THE CONVOLVED SIGNAL (CENTRED) DOWN TO THE SIZE OF THE ORIGINAL ARRAY.

Pad:
ZERO-PADS THE KERNEL OR SIGNAL AT THE END OF EACH AXIS, UP TO THE SHAPE OF THE FULL CONVOLUTION.

GetLength:
EXTRACTS THE LENGTH OF THE UNWOUND STRIP OF THE SIGNAL AND KERNEL THAT IS TO BE CONVOLVED.

FFTConvolve:
USES THE NUMPY FFT PACKAGE TO PERFORM FAST FOURIER CONVOLUTION ON THE SIGNALS 

Convolve:
USES HELIX CONVOLUTION ON AN INPUT ARRAY AND KERNEL. 

"""

import numpy as np
from scipy import signal
import operator
import time


class HelixCPU:
    @classmethod
    def Shrink(cls, array, bounding):
        # Centred crop of `array` to shape `bounding` (same convention as scipy's "same" mode).
        start = tuple(map(lambda a, da: (a - da) // 2, array.shape, bounding))
        end = tuple(map(operator.add, start, bounding))
        slices = tuple(map(slice, start, end))
        return array[slices]

    @classmethod
    def Pad(cls, array, target_shape):
        # Zero-pad at the end of every axis up to target_shape.
        diff = np.subtract(target_shape, array.shape)
        padder = [(0, val) for val in diff]
        padded = np.pad(array, padder, 'constant')
        return padded

    @classmethod
    def GetLength(cls, array_shape, padded_shape):
        # C-order strides of the padded array, e.g. (P1*P2*P3, P2*P3, P3, 1) in 4-D.
        temp = 1
        steps = np.zeros_like(array_shape)
        for i, entry in enumerate(padded_shape[::-1]):
            if i == len(padded_shape) - 1:
                steps[i] = 1
            else:
                temp = entry * temp
                steps[i] = temp
        steps = np.roll(steps, 1)
        steps = steps[::-1]

        # Flat index of the array's last element inside the padded array, plus one.
        ones = np.ones_like(array_shape)
        ones[-1] = 0
        out = np.multiply(steps, array_shape - ones)
        length = np.sum(out)
        return length

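    @classmethod
    def FFTConvolve(cls, in1, in2):
        # Linear (not circular) 1-D convolution: zero-pad both strips to a
        # power of two >= len(in1) + len(in2) - 1 before multiplying spectra.
        shape = len(in1) + len(in2) - 1
        fsize = 2 ** int(np.ceil(np.log2(shape)))
        # The inputs are real, so the real FFT is enough and irfft returns a real array.
        conv = np.fft.irfft(np.fft.rfft(in1, fsize) * np.fft.rfft(in2, fsize), fsize)[:shape]
        return conv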

    @classmethod
    def Convolve(cls, array, kernel):
        m = array.shape
        n = kernel.shape
        mn = np.add(m, n) - 1  # shape of the full N-D convolution
        k_pad = cls.Pad(kernel, mn)
        a_pad = cls.Pad(array, mn)
        # Unwind each padded array into a 1-D strip, dropping the trailing zeros.
        length_k = cls.GetLength(kernel.shape, k_pad.shape)
        length_a = cls.GetLength(array.shape, a_pad.shape)
        k_flat = k_pad.flatten()[0:length_k]
        a_flat = a_pad.flatten()[0:length_a]
        conv = cls.FFTConvolve(a_flat, k_flat)
        # length_a + length_k - 1 == prod(mn), so this is an exact reshape.
        conv = conv.reshape(tuple(mn))
        conv = cls.Shrink(conv, m)  # crop to the "same" size
        return conv



def main():

    array=np.random.rand(25,25,41,51)
    kernel=np.random.rand(10, 10, 10, 10)

    start2 =time.process_time()
    test2 = HelixCPU.Convolve(array, kernel)
    end2=time.process_time()

    start1= time.process_time()
    test1 = signal.fftconvolve(array, kernel, "same")
    end1= time.process_time()

    print ("")
    print ("========================")
    print ("SOME LARGE CONVOLVED RANDOM ARRAYS. ")
    print ("========================")
    print("")
    print ("Random Calorimeter Image of Size {0} Created".format(array.shape))
    print ("Random Kernel of Size {0} Created".format(kernel.shape))
    print("")
    print ("Value\tOriginal\tHelix\tHelix/Original")
    print ("Time Taken [s]\t{0}\t{1}\t{2}".format( (end1-start1), (end2-start2), (end2-start2)/(end1-start1) ))
    print ("Maximum Value\t{:03.2f}\t{:13.2f}".format( np.max(test1), np.max(test2) ))
    print ("Matrix Norm \t{:03.2f}\t{:13.2f}".format( np.linalg.norm(test1), np.linalg.norm(test2) ))
    print ("All Close?\t{0}".format(np.allclose(test1, test2)))


if __name__ == "__main__":
    main()
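GetLength computes the C-order flat index of the array's last element inside the padded array, plus one. A shorter equivalent, assuming NumPy's ravel_multi_index, is sketched below (helix_length is a hypothetical name). It also checks the bookkeeping behind the final reshape: the two strip lengths always satisfy len_a + len_k - 1 == prod(mn), so the 1D result has exactly as many samples as the full N-D output.

```python
import numpy as np


def helix_length(array_shape, padded_shape):
    """Length of the unwound strip up to and including the array's last
    element: the C-order flat index of (s0-1, s1-1, ...) in padded_shape, plus one."""
    last = tuple(s - 1 for s in array_shape)
    return int(np.ravel_multi_index(last, padded_shape)) + 1


# Shapes from main(): the full-convolution shape is m + n - 1 on every axis.
m, n = (25, 25, 41, 51), (10, 10, 10, 10)
mn = tuple(a + b - 1 for a, b in zip(m, n))
len_a, len_k = helix_length(m, mn), helix_length(n, mn)

# The 1-D linear convolution has len_a + len_k - 1 samples, which is exactly
# the number of elements in the full N-D output, so a plain reshape suffices.
print(len_a + len_k - 1 == int(np.prod(mn)))
```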