Need urgent help implementing PatchMatch in PyTorch

I am at a hackathon, and I have spent 8 hours trying to implement PatchMatch for N-channel images.
I know this is a very unreasonable request, but I am sincerely frustrated, and I hope somebody from this community can help me implement it.

I want to implement this paper: http://gfx.cs.princeton.edu/pubs/Barnes_2009_PAR/patchmatch.pdf

but across N channels instead of 3. This is for a neural network, and I will be more than happy to collaborate with whoever helps me.

I would like to add this as a PyTorch extension, along the lines of https://github.com/msracver/Deep-Image-Analogy/blob/master/windows/deep_image_analogy/source/GeneralizedPatchMatch.cuh
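Going from 3 to N channels should only change the patch distance: the sum of squared differences simply runs over the channel axis as well as over the window. Below is a minimal NumPy sketch of that distance, assuming (H, W, C) arrays, an integer nearest-neighbour field (NNF) of absolute (y, x) positions into the second image, and edge padding at the borders. None of these choices come from the paper or the linked .cuh file, and `nnf_cost` is a hypothetical name. Scoring the whole field with array indexing, one window offset at a time, is probably the form that translates most directly to PyTorch tensor indexing.

```python
import numpy as np


def nnf_cost(a, b, nnf, half=1):
    """Patch SSD for every entry of a nearest-neighbour field at once.

    a: (Ha, Wa, C), b: (Hb, Wb, C), nnf: (Ha, Wa, 2) integer (y, x) into `b`.
    Sums over the (2*half+1)^2 window and over all C channels, so C = 3 and
    C = 512 are handled identically.
    """
    ha, wa = a.shape[:2]
    k = 2 * half + 1
    pad = ((half, half), (half, half), (0, 0))
    # Assumption: edge padding so border pixels also get full patches.
    ap = np.pad(a.astype(np.float64), pad, mode="edge")
    bp = np.pad(b.astype(np.float64), pad, mode="edge")
    by, bx = nnf[..., 0], nnf[..., 1]
    cost = np.zeros((ha, wa))
    for dy in range(k):
        for dx in range(k):
            pa = ap[dy:dy + ha, dx:dx + wa]  # (Ha, Wa, C): A at this window offset
            pb = bp[by + dy, bx + dx]        # (Ha, Wa, C): gathered from B via the NNF
            cost += np.sum((pa - pb) ** 2, axis=-1)
    return cost
```

Because each offset is a whole-image shift plus a gather, the same loop could later run on GPU tensors without per-pixel Python code.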

Help is greatly appreciated.

As I mentioned on Twitter, I’d start with something based on NumPy and then adapt/optimize it for PyTorch. Something like https://github.com/jabooth/patchmatch/blob/master/patchmatch.py, for example.
It will definitely be slower than hand-written kernels, but it will definitely be easier to debug.
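Along those lines, here is a minimal NumPy sketch of the three PatchMatch steps from the paper (random initialisation, propagation, random search) for N-channel inputs. It is not the code in the linked repository. `patchmatch`, `half`, `iters` and `seed` are hypothetical names, and the edge padding, the absolute-coordinate NNF, and clamping the random-search window to the image bounds are my own assumptions.

```python
import numpy as np


def patchmatch(a, b, half=1, iters=5, seed=0):
    """Minimal PatchMatch sketch after Barnes et al. (2009) for N-channel arrays.

    a: (Ha, Wa, C), b: (Hb, Wb, C); C can be any channel count.
    Returns nnf (Ha, Wa, 2) with the matched (y, x) in `b` for each pixel of `a`,
    and cost (Ha, Wa) with the patch SSD of each match. Plain Python loops:
    slow, but every step is easy to inspect.
    """
    rng = np.random.default_rng(seed)
    ha, wa = a.shape[:2]
    hb, wb = b.shape[:2]
    k = 2 * half + 1
    pad = ((half, half), (half, half), (0, 0))
    # Assumption: edge padding so that every pixel can be a patch centre.
    ap = np.pad(a.astype(np.float64), pad, mode="edge")
    bp = np.pad(b.astype(np.float64), pad, mode="edge")

    def dist(ay, ax, by, bx):
        # SSD over the k x k window and all C channels.
        d = ap[ay:ay + k, ax:ax + k] - bp[by:by + k, bx:bx + k]
        return float(np.sum(d * d))

    # 1. Random initialisation of the nearest-neighbour field.
    nnf = np.stack([rng.integers(0, hb, (ha, wa)),
                    rng.integers(0, wb, (ha, wa))], axis=-1)
    cost = np.array([[dist(y, x, *nnf[y, x]) for x in range(wa)]
                     for y in range(ha)])

    def improve(y, x, by, bx):
        # Keep candidate (by, bx) only if it lies inside `b` and lowers the cost.
        if 0 <= by < hb and 0 <= bx < wb:
            d = dist(y, x, by, bx)
            if d < cost[y, x]:
                cost[y, x] = d
                nnf[y, x] = (by, bx)

    for it in range(iters):
        # Even iterations scan top-left to bottom-right, odd ones in reverse.
        step = 1 if it % 2 == 0 else -1
        ys = range(ha) if step == 1 else range(ha - 1, -1, -1)
        xs = range(wa) if step == 1 else range(wa - 1, -1, -1)
        for y in ys:
            for x in xs:
                # 2. Propagation: try the visited neighbours' matches, shifted by one.
                if 0 <= y - step < ha:
                    by, bx = nnf[y - step, x]
                    improve(y, x, by + step, bx)
                if 0 <= x - step < wa:
                    by, bx = nnf[y, x - step]
                    improve(y, x, by, bx + step)
                # 3. Random search in windows that halve each time,
                #    clamped to the bounds of `b`.
                r = max(hb, wb)
                while r >= 1:
                    by, bx = nnf[y, x]
                    cy = rng.integers(max(by - r, 0), min(by + r, hb - 1) + 1)
                    cx = rng.integers(max(bx - r, 0), min(bx + r, wb - 1) + 1)
                    improve(y, x, int(cy), int(cx))
                    r //= 2
    return nnf, cost
```

For a (C, H, W) feature tensor `feat`, `feat.permute(1, 2, 0).detach().cpu().numpy()` gives the (H, W, C) layout assumed here. Once the results look right, the per-pixel loops are the part to replace with tensor ops or a custom kernel.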
Good luck!