Speed optimization of a custom layer that uses "torch.meshgrid" and a "for loop"

Hi all, I am implementing a custom layer and found that it runs super slowly. I suspect it is due to the way I use "torch.meshgrid" and a "for loop" inside the forward pass, which cannot benefit from the GPU.

Can anyone help and guide me on how to further optimize the code to speed it up? Thanks.

Here is the layer:


import torch
import torch.nn as nn


def e2e_abs_max_diff(a, b):

    ## calculate the max of the element-wise abs differences between two 1D tensors over all pair combinations
    ## a and b have the same 1 x n shape

    a_c, b_c = torch.meshgrid(a.view(-1), b.view(-1))
    # print(a_c.size(), b_c.size())
    v = torch.sub(a_c, b_c)
    # print(v)
    result = torch.max(torch.abs(v))

    return result
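
For clarity, here is a toy example of what this helper computes (the values are picked just for illustration and rely on the imports above):

a = torch.tensor([1.0, 5.0])
b = torch.tensor([2.0, 9.0])
# pairwise abs differences: |1-2|=1, |1-9|=8, |5-2|=3, |5-9|=4, so the max is 8
print(e2e_abs_max_diff(a, b))  # tensor(8.)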


class MaxDiffMap(nn.Module):
    
    # input: X, Y 5D tensors in (N, C, D, H, W)
    # output: result_windows in (N, C, D, H', W')
    # param: win_a, win_b, stride_a, stride_b
    # Implementation of a max-pooled difference map
  

    def __init__(self, win_a, win_b, stride_a, stride_b):
        super().__init__()
        self.win_a = win_a
        self.win_b = win_b
        self.stride_a = stride_a
        self.stride_b = stride_b

    def forward(self, X, Y):
        ## X and Y are assumed to have the same size (N, C, D, H, W)

        # unfold H and W into sliding windows: (N, C, D, H', W', win_a, win_b)
        XX = X.unfold(3, self.win_a, self.stride_a).unfold(4, self.win_b, self.stride_b)
        # flatten each window: (N, C, D, H', W', win_a * win_b)
        XX = XX.contiguous().view(*XX.size()[:-2], -1)
        result_shape = XX.size()

        YY = Y.unfold(3, self.win_a, self.stride_a).unfold(4, self.win_b, self.stride_b)
        YY = YY.contiguous().view(*YY.size()[:-2], -1)

        # collapse all leading dims so each row is one window: (num_windows, win_a * win_b)
        XX = XX.view(-1, self.win_a * self.win_b)
        YY = YY.view(-1, self.win_a * self.win_b)

        # allocate the per-window output with the same dtype/device as X
        result_windows = torch.zeros(XX.shape[0:-1]).type_as(X)
        result_windows = result_windows.view(-1, 1)

        #print(result_windows.size())

        # this Python loop over every window is the suspected bottleneck
        for i in range(0, XX.shape[0]):
            result_windows[i] = e2e_abs_max_diff(XX[i, :], YY[i, :])

        # reshape back to (N, C, D, H', W')
        result_windows = result_windows.view(result_shape[0:-1])

        return result_windows
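
For context, this is roughly how the layer gets called; the shapes and window/stride values below are just example numbers chosen to keep the loop small, not my real configuration:

layer = MaxDiffMap(win_a=3, win_b=3, stride_a=1, stride_b=1)  # example parameters only
X = torch.randn(2, 1, 4, 16, 16)  # (N, C, D, H, W), example shape
Y = torch.randn(2, 1, 4, 16, 16)
out = layer(X, Y)
print(out.shape)  # torch.Size([2, 1, 4, 14, 14]) for these example values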

Could you add the missing pieces of code to make the current code snippet executable, please? This would allow us to check if e.g. torch.compile could speed up your code.
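
For example, once the snippet runs end to end, a quick check could look roughly like the sketch below (it assumes PyTorch 2.x for torch.compile, reuses the MaxDiffMap definition above, and uses placeholder window/stride values):

import torch

layer = MaxDiffMap(win_a=3, win_b=3, stride_a=1, stride_b=1)  # placeholder values
compiled_layer = torch.compile(layer)  # requires PyTorch 2.x

X = torch.randn(2, 1, 4, 16, 16)
Y = torch.randn(2, 1, 4, 16, 16)

out_eager = layer(X, Y)
out_compiled = compiled_layer(X, Y)
print(torch.allclose(out_eager, out_compiled))  # sanity check before timing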