Hi, I’m trying to implement a filter layer with arbitrary kernel connections. For example:
def ArbitraryConnectFilter(x, kernel, connection):
    '''
    x: input tensor with shape (spatial_in, channels)
    kernel: filter kernel with shape (spatial_out, K)
    connection: indices of x connected to each kernel element, with shape (spatial_out, K)
    '''
    x = x[connection.flatten()].view(*connection.shape, -1)    # (spatial_out, K, channels)
    x = (x.permute(0, 2, 1) @ kernel.unsqueeze(2)).squeeze(2)  # (spatial_out, channels)
    return x
But I realized that the intermediate indexing step, which materializes a tensor of shape (spatial_out, K, channels), might take too much memory when the kernel size K is large. Is there a way to save memory and perform the operation without caching that indexed intermediate tensor? Thanks!
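To make the memory concern concrete: one way to bound the peak memory (though not eliminate the gather entirely) is to process `spatial_out` in chunks, so only a `(chunk_size, K, channels)` slice is materialized at a time. This is just a sketch of that idea, not something from the original post; the function name, the `chunk_size` parameter, and the `einsum` formulation are my own choices:

```python
import torch

def arbitrary_connect_filter_chunked(x, kernel, connection, chunk_size=1024):
    # x: (spatial_in, channels), kernel: (spatial_out, K), connection: (spatial_out, K)
    # Processes spatial_out in chunks so the gathered intermediate is at most
    # (chunk_size, K, channels) instead of (spatial_out, K, channels).
    out = x.new_empty(connection.shape[0], x.shape[1])
    for start in range(0, connection.shape[0], chunk_size):
        idx = connection[start:start + chunk_size]   # (chunk, K)
        gathered = x[idx]                            # (chunk, K, channels)
        # Weighted sum over K; einsum replaces the permute/unsqueeze/squeeze dance.
        out[start:start + chunk_size] = torch.einsum(
            'okc,ok->oc', gathered, kernel[start:start + chunk_size]
        )
    return out
```

Smaller `chunk_size` trades speed for a smaller peak intermediate; the result is identical to the unchunked version up to floating-point rounding.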