Fast Walsh-Hadamard transform

Following my previous question (Fastfood random projection), I am implementing the Fast Walsh-Hadamard transform (FWHT) in PyTorch by extending autograd's Function, so that gradients can be propagated through it. This is straightforward: since the Walsh-Hadamard transform is its own inverse, the same transformation is applied in both the forward and the backward pass.
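
As a quick sanity check of the self-inverse property (just an illustrative sketch using scipy.linalg.hadamard, not part of the implementation):

import numpy as np
from scipy.linalg import hadamard

n = 8
H = hadamard(n) / np.sqrt(n)          # normalized Hadamard matrix (Sylvester construction)
print(np.allclose(H @ H, np.eye(n)))  # True: the normalized transform is its own inverse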

The issue is that I found a large timing difference between computing the transform directly on PyTorch tensors and on NumPy arrays (roughly 20x slower with tensors). Specifically, the following piece of code operates on the tensor:

import numpy as np
import torch
from torch.autograd import Function


class FWHT_OLD(Function):

    @staticmethod
    def transform(tensor):
        """ Simple implementation of FWHT, receiving as input a torch Tensor. """
        bit = length = len(tensor)
        result = tensor.clone()  # copy, so the input is not modified in place

        for _ in range(int(np.log2(length))):
            bit >>= 1
            for i in range(length):  # d loops
                if i & bit == 0:
                    j = i | bit
                    temp = result[i].clone()  # otherwise it copies by reference!
                    result[i] += result[j]
                    result[j] = temp - result[j]

        result /= np.sqrt(length)  # normalize
        return result

    @staticmethod
    def forward(ctx, input):
        return FWHT_OLD.transform(input)

    @staticmethod
    def backward(ctx, grad_output):
        return FWHT_OLD.transform(grad_output)

The following class performs the same operation but converts the tensor to a NumPy array first (only the commented lines changed with respect to the previous version):

class FWHT(Function):

    @staticmethod
    def transform(tensor):
        """ Simple implementation of FWHT, receiving as input a torch Tensor. """
        bit = length = len(tensor)
        result = tensor.detach().numpy().copy()  # convert to numpy (copy, so the input is not modified in place)

        for _ in range(int(np.log2(length))):
            bit >>= 1
            for i in range(length):
                if i & bit == 0:
                    j = i | bit
                    temp = result[i]  # this copies by value
                    result[i] += result[j]
                    result[j] = temp - result[j]

        result /= np.sqrt(length)
        return torch.from_numpy(result)  # convert back to torch

    @staticmethod
    def forward(ctx, input):
        return FWHT.transform(input)

    @staticmethod
    def backward(ctx, grad_output):
        return FWHT.transform(grad_output)
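
For completeness, this is how the Function is meant to be used (a minimal sketch; the input length is assumed to be a power of two):

x = torch.randn(8, requires_grad=True)
y = FWHT.apply(x)   # custom autograd Functions are invoked through .apply()
y.sum().backward()  # backward() applies the same transform to grad_output
print(x.grad)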

The picture below shows a timing analysis of the forward and backward passes of the two implementations. The yellow line is the comparison with pure NumPy code.

[image: fwht_time — timing of the forward and backward passes]
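
A minimal way to reproduce this kind of comparison (the length and the single run here are arbitrary, not the exact settings used for the plot):

import time

x = torch.randn(1024)

start = time.perf_counter()
FWHT_OLD.transform(x)  # pure-tensor version
print('tensor:', time.perf_counter() - start)

start = time.perf_counter()
FWHT.transform(x)      # numpy version
print('numpy: ', time.perf_counter() - start)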

Does anybody have an idea why this happens?

DISCLAIMER: My code was inspired by https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab

Did you ever figure out the speed issue?