What is a fast method of calculating the difference in intensity between a local image patch and its four neighbors?

Hello! I am trying to implement a computer vision paper in PyTorch. It uses a custom loss function that iterates over all local image patches of a specific size and computes the sum of the differences in intensity between each patch and its four neighbors (top, bottom, left, and right).
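To make the computation concrete, here is a naive reference version. It is only a sketch under assumptions that are mine and not taken from the paper: a single-channel 2-D image, non-overlapping patches, the patch mean as the "intensity", and absolute differences. The function name naive_neighbor_diff is hypothetical. It is far too slow for a training loop, but it pins down what a faster version should reproduce.

```python
import torch

def naive_neighbor_diff(img: torch.Tensor, patch: int = 4) -> torch.Tensor:
    """Sum of absolute mean-intensity differences between every patch and
    its up/down/left/right neighbours. img is a 2-D (H, W) tensor
    (assumption: H and W are divisible by patch)."""
    rows, cols = img.shape[0] // patch, img.shape[1] // patch

    # Mean intensity of each non-overlapping patch on a (rows, cols) grid.
    means = torch.zeros(rows, cols)
    for r in range(rows):
        for c in range(cols):
            block = img[r * patch:(r + 1) * patch, c * patch:(c + 1) * patch]
            means[r, c] = block.mean()

    total = torch.tensor(0.0)
    for r in range(rows):
        for c in range(cols):
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                # Neighbours that fall outside the image are skipped.
                if 0 <= nr < rows and 0 <= nc < cols:
                    total = total + (means[r, c] - means[nr, nc]).abs()
    return total
```

Because every patch looks at all four of its neighbors, each adjacent pair of patches is counted twice in the returned sum.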

I am looking for an efficient solution, since this computation will run inside the training loop. Searching online, I found that PyTorch provides a method called unfold, and I used it to extract the patches I need to work with. However, I do not know how to find the neighbors of a given patch, since the spatial structure of the input image is lost when using that method.
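One possible direction, offered as a sketch rather than a definitive answer: as far as I understand, the patches returned by unfold are ordered row-major (left to right, then top to bottom), so with non-overlapping patches the patch dimension can be reshaped back into a 2-D grid, and neighbors become adjacent grid cells. The code below assumes a single-channel batch of shape (N, 1, H, W) with H and W divisible by the patch size, uses the patch mean as the intensity, and takes absolute differences. Those choices and the function names are my assumptions, not the paper's definition.

```python
import torch
import torch.nn.functional as F

def patch_means_grid(img: torch.Tensor, patch: int = 4) -> torch.Tensor:
    """img: (N, 1, H, W). Returns per-patch means as (N, H//patch, W//patch)."""
    n, _, h, w = img.shape
    # unfold gives (N, patch*patch, L); the L patches come in row-major order.
    cols = F.unfold(img, kernel_size=patch, stride=patch)
    means = cols.mean(dim=1)  # (N, L): one mean intensity per patch
    return means.reshape(n, h // patch, w // patch)

def neighbor_diff_sum(img: torch.Tensor, patch: int = 4) -> torch.Tensor:
    """Sum over all patches of |mean(patch) - mean(neighbour)| for the
    four neighbours, computed with shifted slices instead of loops."""
    g = patch_means_grid(img, patch)
    horiz = (g[:, :, 1:] - g[:, :, :-1]).abs().sum()  # left/right pairs
    vert = (g[:, 1:, :] - g[:, :-1, :]).abs().sum()   # up/down pairs
    # Each adjacent pair is seen from both of its patches, hence the factor 2.
    return 2 * (horiz + vert)
```

For non-overlapping patches, F.avg_pool2d(img, patch) should produce the same grid of means without going through unfold at all, which is likely the cheaper option inside a training loop.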

To be more concrete, here is the loss function in question.

[image: the loss function from the paper]

It is from the Zero-Reference Deep Curve Estimation (Zero-DCE) paper for low-light image enhancement.