Memory-efficient filter with arbitrarily connected kernel

Hi, I’m trying to implement a filter layer with arbitrary kernel connections. For example:

def ArbitraryConnectFilter(x, kernel, connection):
    """
    x: input tensor with shape (spatial_in, channels)
    kernel: filter kernel with shape (spatial_out, K)
    connection: indices of x connected to each kernel entry, shape (spatial_out, K)
    """
    x = x[connection.flatten()].view(*connection.shape, -1)  # (spatial_out, K, channels)
    x = (x.permute(0, 2, 1) @ kernel.unsqueeze(2)).squeeze(2)  # (spatial_out, channels)
    return x

But I realized that the intermediate indexing step can take too much memory when the kernel size K is large: x[connection.flatten()] materializes a full (spatial_out, K, channels) tensor before the reduction. Is there any way to save memory and perform the operation without caching this indexed intermediate tensor? Thanks!
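For context, the only workaround I've come up with so far is to accumulate tap by tap, so only one (spatial_out, channels) gather is alive at a time instead of the full (spatial_out, K, channels) tensor. This is just a sketch (the function name is made up), and it trades the memory for a Python loop over K, which is why I'm hoping there's something more vectorized:

```python
import torch

def arbitrary_connect_filter_lowmem(x, kernel, connection):
    # Hypothetical low-memory variant: accumulate one kernel tap at a time.
    # x: (spatial_in, channels), kernel: (spatial_out, K), connection: (spatial_out, K)
    spatial_out, K = connection.shape
    out = torch.zeros(spatial_out, x.shape[1], dtype=x.dtype, device=x.device)
    for k in range(K):
        # Gather only the k-th tap: (spatial_out, channels),
        # then scale by the k-th kernel column and accumulate.
        out += x[connection[:, k]] * kernel[:, k].unsqueeze(1)
    return out

# Small check against the full-gather version
x = torch.randn(10, 3)
kernel = torch.randn(5, 4)
connection = torch.randint(0, 10, (5, 4))
ref = (x[connection.flatten()].view(5, 4, -1).permute(0, 2, 1)
       @ kernel.unsqueeze(2)).squeeze(2)
out = arbitrary_connect_filter_lowmem(x, kernel, connection)
print(torch.allclose(ref, out, atol=1e-5))
```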