Optimizing double ‘for’ pixel loops

Hi All,
I have a layer for a U-Net model; the layer's kernel must loop over every pixel in the image (image size is 1024x1024), and it is really slow: it takes 2.15 minutes on average per image per iteration. Any suggestions on how I can optimize the double loop over the image matrix?

Can you post the code? There might be a way to vectorize it or compose it from faster native operations.

def forward(self, details, norms, max_pixels, min_pixels):
    # Per-pixel double loop: apply the weight function only to negative
    # detail pixels; non-negative pixels pass through unchanged.
    return torch.FloatTensor(
        [[torch.where(pixel < 0, weight(pixel, nrm, mx, mn), pixel)
          for pixel, nrm, mx, mn in zip(detail, norm, max_row, min_row)]
         for detail, norm, max_row, min_row in zip(details, norms, max_pixels, min_pixels)]
    ).reshape((1024, 1024))

I organized the details, norms, min pixels, and max pixels into a separate dataset class (with torch.utils.data.Dataset); these quantities are computed once at the beginning and then used by the kernel.
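Roughly, the dataset class looks like this (a simplified sketch with placeholder names, assuming each quantity is precomputed as a stack of 1024x1024 tensors):

import torch
from torch.utils.data import Dataset

class PixelStatsDataset(Dataset):  # placeholder name for illustration
    def __init__(self, details, norms, max_pixels, min_pixels):
        # Each argument: precomputed tensor of shape (N, 1024, 1024).
        self.details = details
        self.norms = norms
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels

    def __len__(self):
        return len(self.details)

    def __getitem__(self, idx):
        return (self.details[idx], self.norms[idx],
                self.max_pixels[idx], self.min_pixels[idx])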
The weight function is:

weight = lambda detail, norm, max_pixel, min_pixel: detail / torch.exp(
    torch.pow((max_pixel * self.lamda) / (norm * min_pixel + self.epsilon), 0.01))

It is a bit difficult to reason about the computation without seeing some concrete shapes, but I think the general pattern might be amenable to vectorization as I don’t see any obvious tricky control flow.

Is where being called on just a single pixel? It would likely be faster to precompute both branches for the whole tensor first and then select between them with a single torch.where call.
I think this can be done if the weight function can be vectorized across pixels (just make sure that all of the inputs have the same shape or can be broadcast in a way that makes sense).
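For example, here is a minimal sketch of a vectorized forward, assuming details, norms, max_pixels, and min_pixels are all (1024, 1024) tensors (or broadcastable to that shape) and that self.lamda and self.epsilon are the same attributes as in your layer:

import torch

def forward(self, details, norms, max_pixels, min_pixels):
    # Evaluate the weight function for every pixel at once; the
    # arithmetic is applied elementwise across the whole tensor.
    weighted = details / torch.exp(
        torch.pow((max_pixels * self.lamda) / (norms * min_pixels + self.epsilon), 0.01)
    )
    # Select the weighted value where the detail pixel is negative,
    # otherwise keep the original detail pixel.
    return torch.where(details < 0, weighted, details)

This computes the weight branch even for non-negative pixels, but a few fused elementwise passes over a 1024x1024 tensor should still be far faster than roughly a million Python-level torch.where calls.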