I have implemented an N-dimensional Hartley spectral pooling function (see this paper for more information) for PyTorch, using CuPy and SigPy, as shown below.
While this function seems to work fine, is there a way to further optimize this code to make it faster?
I am interested in data with up to four dimensions, so I have tried to support an arbitrary number of dimensions wherever possible. Unfortunately, PyTorch does not seem to have a four-dimensional FFT function, so I have instead used other GPU libraries to perform the FFT on the GPU.
import math
import operator

import cupy as cp
import numpy as np
import sigpy as sp
import torch
import torch.nn as nn
from torch.autograd import Function
def _spectral_crop(array, array_shape, bounding_shape):
    """Return the centred sub-window of ``array`` whose shape is ``bounding_shape``.

    On each axis the window starts at ``(size - new_size) // 2``, so when the
    size difference is odd the extra element is dropped from the end. Axes are
    paired with ``zip``, so only as many leading axes as the shorter of the two
    shapes are sliced (the same pairing the original ``map`` version used).

    Args:
        array: Array supporting tuple-of-slices indexing (e.g. NumPy or CuPy).
        array_shape: Shape of ``array``.
        bounding_shape: Requested output size per axis.

    Returns:
        ``array`` indexed with the centred slices (basic slicing, so NumPy and
        CuPy return a view rather than a copy).

    Raises:
        ValueError: If a requested size exceeds the corresponding size in
            ``array_shape``.
    """
    slices = []
    for size, new_size in zip(array_shape, bounding_shape):
        # Without this check the start index goes negative and Python's
        # negative-index slicing silently yields a window of the wrong size.
        if new_size > size:
            raise ValueError(
                f"cannot crop an axis of size {size} to size {new_size}"
            )
        start = (size - new_size) // 2
        slices.append(slice(start, start + new_size))
    return array[tuple(slices)]
def _spectral_pad(array, array_shape, bounding_shape):
    """Embed ``array`` in a zero-filled array of shape ``bounding_shape``.

    ``array`` is written into the window starting at ``(new_size - size) // 2``
    on every axis (the same centring rule ``_spectral_crop`` uses to cut the
    window out); every other element is zero.

    Args:
        array: NumPy or CuPy array to embed.
        array_shape: Shape of ``array``.
        bounding_shape: Shape of the zero-filled output.

    Returns:
        A new array of shape ``bounding_shape`` with the same dtype and array
        module (NumPy or CuPy) as ``array``.

    Raises:
        ValueError: If ``array`` is larger than ``bounding_shape`` on some axis.
    """
    out_shape = tuple(int(n) for n in bounding_shape)

    slices = []
    for size, new_size in zip(array_shape, out_shape):
        if size > new_size:
            raise ValueError(
                f"cannot pad an axis of size {size} to size {new_size}"
            )
        start = (new_size - size) // 2
        slices.append(slice(start, start + size))

    # np.zeros_like dispatches to cupy.zeros_like for CuPy inputs through
    # NumPy's __array_function__ protocol (NEP 18), so the buffer is created
    # on the same device. Keeping array.dtype avoids the float64 upcast that
    # cp.zeros(shape) (default dtype float64) used to apply to float32 data.
    out = np.zeros_like(array, shape=out_shape)
    out[tuple(slices)] = array
    return out
def DiscreteHartleyTransform(input):
    """Return the discrete Hartley transform of ``input`` over axes 2, 3, ...

    Axes 0 and 1 are not transformed (presumably batch and channel; confirm
    with callers). The transform is derived from the complex FFT ``F`` as
    ``H = Re(F) - Im(F)``, which is only the Hartley transform when ``input``
    is real.

    The FFT is requested centred and orthonormal explicitly. With that scaling
    the Hartley transform is its own inverse, which ``CropForward`` and
    ``PadBackward`` rely on when they apply this function a second time to
    return to the spatial domain.

    Args:
        input: Real NumPy/CuPy array with at least two leading axes.

    Returns:
        Real array with the same shape as ``input``.
    """
    # A plain tuple of ints is all sigpy needs; no numpy array required.
    spatial_axes = tuple(range(2, input.ndim))
    spectrum = sp.fft(input, axes=spatial_axes, center=True, norm='ortho')
    return spectrum.real - spectrum.imag
def CropForward(input, return_shape):
    """Spectrally pool ``input`` down to ``return_shape`` on its spatial axes.

    Applies the Hartley transform, keeps the centred window of the spectrum
    given by ``return_shape`` on axes 2 and above, then applies the Hartley
    transform again. Axes 0 and 1 keep their full size.

    Args:
        input: NumPy/CuPy array with ``input.ndim - 2`` spatial axes.
        return_shape: Output size of the spatial axes. Either one entry per
            spatial axis, or a single value (scalar or length-1 sequence)
            applied to every spatial axis.

    Returns:
        Array of shape ``input.shape[:2] + tuple(return_shape)``.

    Raises:
        ValueError: If ``return_shape`` has neither exactly one entry nor one
            entry per spatial axis.
    """
    n_spatial = input.ndim - 2
    spatial = np.asarray(return_shape).astype(int).reshape(-1)
    if spatial.size == 1:
        # Preserve the original broadcasting behaviour: a single size is
        # used for every spatial axis.
        spatial = np.repeat(spatial, n_spatial)
    if spatial.size != n_spatial:
        raise ValueError(
            f"return_shape has {spatial.size} entries but the input has "
            f"{n_spatial} spatial axes"
        )
    output_shape = tuple(input.shape[:2]) + tuple(int(s) for s in spatial)

    dht = DiscreteHartleyTransform(input)
    dht = _spectral_crop(dht, dht.shape, output_shape)
    return DiscreteHartleyTransform(dht)
def PadBackward(grad_output, input_shape):
    """Lift a pooled gradient back to the full ``input_shape``.

    Used as the backward counterpart of ``CropForward``. The gradient is
    Hartley-transformed, its spectrum is zero-padded (centred) out to
    ``input_shape``, and the result is Hartley-transformed again.

    Args:
        grad_output: Gradient array with the pooled shape.
        input_shape: Full shape (all axes) of the original forward input.

    Returns:
        Array of shape ``input_shape``.
    """
    spectrum = DiscreteHartleyTransform(grad_output)
    padded = _spectral_pad(spectrum, spectrum.shape, input_shape)
    return DiscreteHartleyTransform(padded)
class SpectralPoolingFunction(Function):
    """Autograd function implementing Hartley spectral pooling.

    Tensors are handed to sigpy with ``sp.from_pytorch`` and brought back
    with ``sp.to_pytorch``. ``forward`` crops the spatial Hartley spectrum to
    ``return_shape``; ``backward`` zero-pads the gradient's spectrum back to
    the shape of the forward input.
    """

    @staticmethod
    def forward(ctx, input, return_shape):
        """Pool the spatial axes of ``input`` to ``return_shape``.

        The result has the same dtype as ``input``.
        """
        # Export a detached tensor: the sigpy/cupy side never needs the graph.
        array = sp.from_pytorch(input.detach())
        ctx.input_shape = array.shape
        output = CropForward(array, return_shape)
        # autograd attaches the graph to forward's result itself. The
        # converted tensor should not be a grad-requiring leaf; sigpy
        # documents requires_grad=True as the default for to_pytorch.
        output = sp.to_pytorch(output, requires_grad=False)
        # Preserve the caller's dtype instead of forcing float32
        # (.float() silently downcast float64 inputs).
        return output.to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Map the output gradient back to the shape of the forward input."""
        grad = sp.from_pytorch(grad_output.detach())
        grad_input = PadBackward(grad, ctx.input_shape)
        grad_input = sp.to_pytorch(grad_input, requires_grad=False)
        # Exactly one gradient per forward input: ``input`` and the
        # non-differentiable ``return_shape``.
        return grad_input.to(grad_output.dtype), None
class SpectralPoolNd(nn.Module):
    """Module wrapper around ``SpectralPoolingFunction``.

    Pools every axis after the first two of its input to ``return_shape`` by
    cropping the Hartley spectrum.

    Args:
        return_shape: Target size of the spatial axes (one entry per spatial
            axis, or a single value for all of them).
    """

    def __init__(self, return_shape):
        super().__init__()
        self.return_shape = return_shape

    def forward(self, input):
        """Apply spectral pooling to ``input``."""
        return SpectralPoolingFunction.apply(input, self.return_shape)

    def extra_repr(self):
        # Shown by print(model) / repr(module), e.g.
        # "SpectralPoolNd(return_shape=(4, 4))".
        return f"return_shape={self.return_shape}"