Need urgent help implementing PatchMatch in PyTorch

I am at a hackathon, and I have spent 8 hours trying to implement PatchMatch for images with N channels.
I know this is very unreasonable, but I am sincerely frustrated, and I hope somebody from this community can help me implement it.

I want to implement this paper:

but instead of 3 channels, I want it to work across N channels. This is for a neural network, and I will be more than happy to collaborate with whoever helps me.

I would like to add this as a PyTorch extension.

Help is greatly appreciated.

As I mentioned on Twitter, I’d start with something based on NumPy and then adapt and optimize it for PyTorch. Something like for example.
It will definitely be slower than hand-written kernels, but it will be much easier to debug.
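To make the NumPy-first suggestion concrete, here is a minimal sketch of what such a starting point could look like. It is not taken from the paper's code or from any existing library: the function names `patch_distance` and `patchmatch`, the `(H, W, C)` array layout, the sum-of-squared-differences cost, and the default parameters are all assumptions chosen for illustration. The only thing that makes it N-channel is that the patch cost sums over the last axis, whatever its size.

```python
import numpy as np


def patch_distance(a, b, ay, ax, by, bx, p):
    """Sum of squared differences between two p x p patches, over all channels."""
    diff = a[ay:ay + p, ax:ax + p] - b[by:by + p, bx:bx + p]
    return float(np.sum(diff * diff))


def patchmatch(a, b, patch_size=3, iters=4, seed=0):
    """Approximate nearest-neighbour field from a to b, both shaped (H, W, C).

    Returns nnf with nnf[y, x] = (by, bx), the top-left corner in b of the
    patch matched to the patch whose top-left corner in a is (y, x), and the
    matching cost of each entry.
    """
    rng = np.random.default_rng(seed)
    p = patch_size
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    # Number of valid top-left patch corners in each image.
    ah, aw = a.shape[0] - p + 1, a.shape[1] - p + 1
    bh, bw = b.shape[0] - p + 1, b.shape[1] - p + 1

    # Random initialisation of the field and its costs.
    nnf = np.stack([rng.integers(0, bh, (ah, aw)),
                    rng.integers(0, bw, (ah, aw))], axis=-1)
    cost = np.empty((ah, aw))
    for y in range(ah):
        for x in range(aw):
            cost[y, x] = patch_distance(a, b, y, x, nnf[y, x, 0], nnf[y, x, 1], p)

    def try_candidate(y, x, by, bx):
        # Keep the candidate only if it is in bounds and strictly better.
        if 0 <= by < bh and 0 <= bx < bw:
            d = patch_distance(a, b, y, x, by, bx, p)
            if d < cost[y, x]:
                cost[y, x] = d
                nnf[y, x] = (by, bx)

    for it in range(iters):
        # Alternate scan direction: forward on even passes, backward on odd.
        step = 1 if it % 2 == 0 else -1
        ys = range(ah) if step == 1 else range(ah - 1, -1, -1)
        xs = range(aw) if step == 1 else range(aw - 1, -1, -1)
        for y in ys:
            for x in xs:
                # Propagation: reuse the already-visited neighbours' matches,
                # shifted by one pixel.
                if 0 <= y - step < ah:
                    ny, nx = nnf[y - step, x]
                    try_candidate(y, x, ny + step, nx)
                if 0 <= x - step < aw:
                    ny, nx = nnf[y, x - step]
                    try_candidate(y, x, ny, nx + step)
                # Random search around the current best, halving the radius.
                cy, cx = nnf[y, x]
                radius = max(bh, bw)
                while radius >= 1:
                    ry = cy + rng.integers(-radius, radius + 1)
                    rx = cx + rng.integers(-radius, radius + 1)
                    try_candidate(y, x, ry, rx)
                    radius //= 2
    return nnf, cost
```

Calling it on two feature maps, for example `nnf, cost = patchmatch(feat_a, feat_b, patch_size=3, iters=4)` with `feat_a` and `feat_b` as `(H, W, C)` arrays, gives a field you can inspect pixel by pixel. The Python loops make this slow on anything but small inputs, so it is best treated as a reference to compare a PyTorch or custom-kernel version against.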
Good luck!