# Optimizing Custom Hartley Pooling Layer

I have implemented an N-dimensional Hartley pooling function (see this paper for more information) for PyTorch using `cupy` and `sigpy`, as shown below.

The function seems to work correctly, but is there a way to optimize this code further to make it faster?

I am interested in data with up to four dimensions, so I have tried to keep the code N-dimensional wherever possible. Unfortunately, PyTorch does not seem to have a 4-dimensional FFT function, so I have instead used other GPU libraries (`cupy` and `sigpy`) to perform the FFT on the GPU.

```python
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import operator
import cupy as cp
import sigpy as sp

def _spectral_crop(array, array_shape, bounding_shape):
start = tuple(map(lambda a, da: (a-da)//2, array_shape, bounding_shape))
end = tuple(map(operator.add, start, bounding_shape))
slices = tuple(map(slice, start, end))
return array[slices]

def _spectral_pad(array, array_shape, bounding_shape):
out = cp.zeros(bounding_shape)
start = tuple(map(lambda a, da: (a-da)//2, bounding_shape, array_shape))
end = tuple(map(operator.add, start, array_shape))
slices = tuple(map(slice, start, end))
out[slices] = array
return out

def DiscreteHartleyTransform(input):
    """Return the discrete Hartley transform of ``input`` over axes 2..ndim-1.

    The first two axes are not transformed (presumably batch and channel;
    confirm against callers). The Hartley transform is derived from the FFT
    as ``Re(F) - Im(F)``, which corresponds to the ``cas = cos + sin`` kernel.

    ``sigpy.fft`` is called with its defaults. The sigpy docs give these as
    ``center=True`` (fftshift-centred spectrum) and ``norm='ortho'``. With
    orthonormal scaling the Hartley transform is its own inverse, which is
    why ``CropForward`` can apply this same function to return to the
    spatial domain.

    Args:
        input: Array with two leading, untransformed axes.

    Returns:
        Real-valued array with the same shape as ``input``.
    """
    # Plain Python ints are enough to name the axes; no NumPy needed here.
    spatial_axes = tuple(range(2, input.ndim))
    spectrum = sp.fft(input, axes=spatial_axes)
    return spectrum.real - spectrum.imag

def CropForward(input, return_shape):
    """Spectrally pool ``input`` down to spatial size ``return_shape``.

    Applies the Hartley transform over the non-leading axes and keeps the
    centred low-frequency window of size ``return_shape``. It then applies
    the (self-inverse) Hartley transform again to get back to the spatial
    domain.

    Args:
        input: Array of shape ``(d0, d1, *spatial)``. The first two axes are
            kept as-is (presumably batch and channel; confirm against callers).
        return_shape: Target spatial shape, one entry per spatial axis. A
            single int, or a length-1 sequence, is applied to every spatial
            axis.

    Returns:
        Real array of shape ``(d0, d1, *return_shape)``.

    Raises:
        ValueError: if ``return_shape`` has the wrong number of entries or is
            larger than the input along any spatial axis.
    """
    spatial_ndim = input.ndim - 2
    try:
        entries = list(return_shape)
    except TypeError:
        # A single number applies to every spatial axis.
        entries = [return_shape]
    spatial_shape = tuple(int(s) for s in entries)
    if len(spatial_shape) == 1:
        # Length-1 targets are broadcast across all spatial axes.
        spatial_shape *= spatial_ndim
    if len(spatial_shape) != spatial_ndim:
        raise ValueError(
            f"return_shape has {len(spatial_shape)} entries, expected "
            f"{spatial_ndim} (one per spatial axis)"
        )
    input_spatial = tuple(input.shape[2:])
    # Cropping can only shrink; a larger target would yield negative slice
    # starts and a silently wrong output shape.
    if any(m > n for m, n in zip(spatial_shape, input_spatial)):
        raise ValueError(
            f"return_shape {spatial_shape} exceeds input spatial shape "
            f"{input_spatial}"
        )
    output_shape = tuple(input.shape[:2]) + spatial_shape

    spectrum = DiscreteHartleyTransform(input)
    spectrum = _spectral_crop(spectrum, spectrum.shape, output_shape)
    return DiscreteHartleyTransform(spectrum)

# NOTE(review): these lines are not inside any visible definition. They read
# `dht` and `input_shape`, neither of which is defined at this point, and they
# end in a bare `return`. If they belong to CropForward, they are unreachable
# after its `return dht`. If they are at module level, the `return` is a
# SyntaxError. They look like the tail of a separate function whose `def` line
# (and whatever first assigns `dht`) is missing from this listing, possibly
# the padding counterpart of CropForward used by the backward pass.
# TODO: restore the missing header; confirm against the original source.
# Zero-pad the spectrum `dht`, centred, out to `input_shape`.
dht = _spectral_pad(dht, dht.shape, input_shape)
# Apply the Hartley transform to the padded array.
dht = DiscreteHartleyTransform(dht)

return dht

class SpectralPoolingFunction(Function):
# Forward pass of the spectral pooling autograd Function.
# Args: ctx is the autograd context; input is a torch tensor; return_shape is
# the target spatial shape, passed through to CropForward.
# Returns: the pooled result as a torch tensor, cast to float32 by `.float()`.
@staticmethod
def forward(ctx, input, return_shape):
# Convert the torch tensor to a sigpy/cupy array for CropForward.
# NOTE(review): presumably zero-copy on the GPU; confirm in the sigpy docs.
input = sp.from_pytorch(input)
# Stored on ctx for later use, presumably by backward (not visible here;
# confirm).
ctx.input_shape = input.shape
output = CropForward(input, return_shape)
output = sp.to_pytorch(output)
# `.float()` casts to float32 regardless of the incoming tensor's dtype.
output = output.float()
return output

@staticmethod