Extracting sliding window patches using indices rather than unfold


I’m working on a problem where I need to perform convolution on non-uniform pixel arrangements. Because the pixels have no canonical ordering, the input has shape (Batch, Channels, N_pixels). I construct a KNN graph to determine which input indices each output neuron is connected to, and use those indices to extract ‘sliding window’ patches. Below is an example module that implements the operation.

import torch
import torch.nn as nn

class GraphConv(nn.Module):
    def __init__(self, in_channels, out_channels, k, in_size):
        super().__init__()
        # idx generated from knn_graph; random here for illustration
        idx = torch.randint(0, in_size, (in_size, k))
        weights = torch.rand((in_channels, out_channels, in_size, k))
        self.weights = nn.Parameter(weights, requires_grad=True)
        self.register_buffer('idx', idx)
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1))

    def forward(self, x):
        x = x[:,:,self.idx]
        #b, i, o, n, k - batch, in_channels, out_channels, n_pixels, k
        x = torch.einsum('bink,ionk->bon', x, self.weights)
        x = x + self.bias
        return x
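For context, the real `idx` buffer comes from the known pixel coordinates. Roughly, it is built like this (random coordinates here for illustration; in my case they are fixed and shared by all inputs):

```python
import torch

torch.manual_seed(0)
n_pixels, k = 1024, 9
coords = torch.rand(n_pixels, 2)            # stand-in for the known pixel positions
dists = torch.cdist(coords, coords)         # (N, N) pairwise distances
idx = dists.topk(k, largest=False).indices  # (N, k) indices of the k nearest pixels
```

Since the arrangement never changes, this is computed once up front, not per forward pass.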

If I am understanding correctly, I can’t use the unfold operation to extract sliding windows of the input due to the arbitrary ordering of the input pixels. While the above works, it is ~10x slower than an equivalent Conv1d. Is there any way to optimise this to run faster? The pixel arrangement of the input is known beforehand and is the same for all input images.
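For reference, this is roughly how I’m measuring the slowdown (CPU wall-clock, illustrative sizes; the random `idx` stands in for the real `knn_graph` output):

```python
import time
import torch
import torch.nn as nn

class GraphConv(nn.Module):
    def __init__(self, in_channels, out_channels, k, in_size):
        super().__init__()
        # random idx stands in for the real knn_graph output
        self.register_buffer('idx', torch.randint(0, in_size, (in_size, k)))
        self.weights = nn.Parameter(torch.rand(in_channels, out_channels, in_size, k))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1))

    def forward(self, x):
        x = x[:, :, self.idx]  # (b, in_channels, n_pixels, k)
        return torch.einsum('bink,ionk->bon', x, self.weights) + self.bias

def bench(module, x, reps=20):
    # warm up, then average wall-clock time per forward pass
    with torch.no_grad():
        for _ in range(3):
            module(x)
        t0 = time.perf_counter()
        for _ in range(reps):
            module(x)
    return (time.perf_counter() - t0) / reps

b, c_in, c_out, n, k = 8, 16, 16, 1024, 9
x = torch.rand(b, c_in, n)
graph = GraphConv(c_in, c_out, k, n)
conv = nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2)
print(f'GraphConv: {bench(graph, x) * 1e3:.2f} ms')
print(f'Conv1d:    {bench(conv, x) * 1e3:.2f} ms')
```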

Thanks in advance!