Efficient pseudo-inverse for PyTorch 2D convolution

Background:

Thanks for your attention! I am learning the basics of 2D convolution, linear algebra, and PyTorch. I have run into a problem implementing the pseudo-inverse of the convolution operator. Specifically, I do not know how to implement it efficiently. Please see the problem statement below for details. Any help, tips, or suggestions are welcome.

(Thanks a lot for your attention!)


The Original Problem:

I have an image feature x with shape [b,c,h,w] and a 3x3 convolutional kernel K with shape [c,c,3,3], and y = K * x. How can I efficiently apply the corresponding pseudo-inverse to y?

That is, given [y = K * x = Ax], how can I implement [x_hat = (A^+)y]?

I guess the solution involves torch.fft, but I do not know how to implement it, or whether an implementation already exists.

import torch
import torch.nn.functional as F

c = 32
K = torch.randn(c, c, 3, 3)
x = torch.randn(1, c, 128, 128)
y = F.conv2d(x, K, padding=1)

print(y.shape)

# How to implement pseudo-inverse for y = K * x in an efficient way?

Some of My Efforts:

I understand that 2D convolution is a linear operator, so it is equivalent to a “matrix product”. We could write out the matrix form of the convolution and compute its pseudo-inverse explicitly, but I think that would be inefficient, and I do not know how to do it efficiently.
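To make the matrix form concrete, here is a minimal sketch on a deliberately tiny problem (2 channels, 5x5 image). It builds the dense matrix A by pushing one-hot basis images through F.conv2d and then calls torch.linalg.pinv. The sizes and variable names are assumptions chosen for illustration; at the sizes in the question (c = 32, 128x128) the matrix would have 524288 rows and columns, which is why this route is impractical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c, h, w = 2, 5, 5                      # tiny sizes, chosen only for illustration
K = torch.randn(c, c, 3, 3, dtype=torch.float64)

n = c * h * w
# Each row of `basis` is a one-hot image e_j reshaped to [c, h, w].
basis = torch.eye(n, dtype=torch.float64).reshape(n, c, h, w)
# conv(e_j) flattened is column j of A, so transpose the stacked rows.
A = F.conv2d(basis, K, padding=1).reshape(n, n).T

x = torch.randn(1, c, h, w, dtype=torch.float64)
y = F.conv2d(x, K, padding=1)

# The dense matrix reproduces the convolution.
print(torch.allclose(A @ x.reshape(-1), y.reshape(-1)))

# Dense pseudo-inverse applied to y (O(n^3) cost, O(n^2) memory).
A_pinv = torch.linalg.pinv(A)
x_hat = (A_pinv @ y.reshape(-1)).reshape(1, c, h, w)
print(x_hat.shape)
```

This is only useful as a reference to check a faster method against on small inputs.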

According to Wikipedia, the pseudo-inverse satisfies A(A_pinv(A(x))) = A(x), where A is the convolution operator, A_pinv is its pseudo-inverse, and x is any image feature. In other words, A(A_pinv(y)) = y holds for any y in the range of A, such as y = A(x).
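For reference, the Moore-Penrose pseudo-inverse A^+ is the unique matrix satisfying the four Penrose conditions below. The property quoted above corresponds to the first condition; the symbol \operatorname{range}(A) for the column space of A is notation introduced here.

```latex
\begin{aligned}
A A^{+} A &= A, \\
A^{+} A A^{+} &= A^{+}, \\
(A A^{+})^{*} &= A A^{+}, \\
(A^{+} A)^{*} &= A^{+} A.
\end{aligned}
\qquad\Longrightarrow\qquad
A A^{+} y = y \iff y \in \operatorname{range}(A).
```

The last statement holds because A A^+ is the orthogonal projector onto the range of A.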

(Thanks again for reading such a long post!)

The thing here is that the inverse of a sparse matrix (like convolution) is not, in general, sparse / similar in structure, so if you were to derive the inverse by looking at the matrix and inverting that, you would not get a proper inverse.
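Although the inverse is dense in the spatial domain, the structure can still be exploited in the Fourier domain. The sketch below assumes circular (periodic) boundary handling, which is not what F.conv2d(x, K, padding=1) in the question does (that call zero-pads), so for the original setup it is only an approximation near the borders. Under that assumption the operator is block-diagonal in frequency: each frequency has a small [c_out, c_in] matrix, and the pseudo-inverse of the whole operator is obtained by taking the pseudo-inverse of each block. The function names circular_conv2d, conv_transfer and conv_pinv_fft are hypothetical helpers, not PyTorch APIs.

```python
import torch
import torch.nn.functional as F

def circular_conv2d(x, K):
    # Forward model assumed here: F.conv2d with circular padding.
    ph, pw = K.shape[-2] // 2, K.shape[-1] // 2
    return F.conv2d(F.pad(x, (pw, pw, ph, ph), mode="circular"), K)

def conv_transfer(K, h, w):
    # Per-frequency matrices H[f] with shape [h, w//2+1, c_out, c_in].
    c_out, c_in, kh, kw = K.shape
    Kp = K.new_zeros(c_out, c_in, h, w)
    Kp[..., :kh, :kw] = K
    # Move the kernel centre to index (0, 0).
    Kp = torch.roll(Kp, shifts=(-(kh // 2), -(kw // 2)), dims=(-2, -1))
    # F.conv2d computes cross-correlation, hence the complex conjugate.
    H = torch.conj_physical(torch.fft.rfft2(Kp))
    return H.permute(2, 3, 0, 1)

def conv_pinv_fft(y, K):
    # Apply the pseudo-inverse of the circular convolution to y.
    b, c_out, h, w = y.shape
    H = conv_transfer(K, h, w)
    H_pinv = torch.linalg.pinv(H)          # batched: [h, w//2+1, c_in, c_out]
    Y = torch.fft.rfft2(y)                 # [b, c_out, h, w//2+1]
    X = torch.einsum("hwio,bohw->bihw", H_pinv, Y)
    return torch.fft.irfft2(X, s=(h, w))

if __name__ == "__main__":
    torch.manual_seed(0)
    K = torch.randn(4, 4, 3, 3, dtype=torch.float64)
    x = torch.randn(1, 4, 16, 16, dtype=torch.float64)
    y = circular_conv2d(x, K)
    x_hat = conv_pinv_fft(y, K)
    # Re-applying the operator should reproduce y, since y is in its range.
    print((circular_conv2d(x_hat, K) - y).abs().max())
```

The cost is a few FFTs plus one batched SVD of h * (w//2 + 1) matrices of size [c_out, c_in], instead of one SVD of a [c*h*w, c*h*w] matrix.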


Thank you so much, @tom! I read the reference you provided. It seems that Wiener filtering would be useful for me. Is there an efficient way to implement it in PyTorch for a [b,c,h,w] feature and a [c,c,k,k] convolutional kernel?
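One possible shape for this, offered as a sketch rather than a definitive answer: a Wiener-style regularized inverse computed per frequency. It assumes circular boundary handling (unlike zero padding in F.conv2d(..., padding=1)) and uses a single constant lam in place of the noise-to-signal power ratio. A full Wiener filter would use frequency-dependent signal and noise spectra, which are not given in this thread. The names circular_conv2d, wiener_deconv_fft and lam are hypothetical.

```python
import torch
import torch.nn.functional as F

def circular_conv2d(x, K):
    # Forward model assumed here: F.conv2d with circular padding.
    ph, pw = K.shape[-2] // 2, K.shape[-1] // 2
    return F.conv2d(F.pad(x, (pw, pw, ph, ph), mode="circular"), K)

def wiener_deconv_fft(y, K, lam=1e-3):
    # Solves (H^H H + lam * I) X = H^H Y independently at every frequency.
    b, c_out, h, w = y.shape
    _, c_in, kh, kw = K.shape
    Kp = K.new_zeros(c_out, c_in, h, w)
    Kp[..., :kh, :kw] = K
    # Move the kernel centre to index (0, 0).
    Kp = torch.roll(Kp, shifts=(-(kh // 2), -(kw // 2)), dims=(-2, -1))
    # Conjugate because F.conv2d is a cross-correlation.
    H = torch.conj_physical(torch.fft.rfft2(Kp)).permute(2, 3, 0, 1)
    Hh = H.transpose(-2, -1).conj()                  # [h, wf, c_in, c_out]
    eye = torch.eye(c_in, dtype=H.dtype, device=H.device)
    G = Hh @ H + lam * eye                           # [h, wf, c_in, c_in]
    Y = torch.fft.rfft2(y).permute(2, 3, 1, 0)       # [h, wf, c_out, b]
    X = torch.linalg.solve(G, Hh @ Y)                # [h, wf, c_in, b]
    return torch.fft.irfft2(X.permute(3, 2, 0, 1), s=(h, w))

if __name__ == "__main__":
    torch.manual_seed(0)
    K = torch.randn(4, 4, 3, 3, dtype=torch.float64)
    x = torch.randn(1, 4, 16, 16, dtype=torch.float64)
    y = circular_conv2d(x, K) + 0.01 * torch.randn(1, 4, 16, 16, dtype=torch.float64)
    x_hat = wiener_deconv_fft(y, K, lam=1e-2)
    print(x_hat.shape)
```

As lam goes to zero this approaches the pseudo-inverse; larger values trade fidelity for less noise amplification at frequencies where the kernel response is weak.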