Time issues with F.unfold

I’m currently struggling with some computation time issues.
I have an input and a filter like this:

z = torch.from_numpy(np.random.randint(1, 10, (bs, channels, z_w, z_h))).float()
f = torch.from_numpy(np.random.randint(1, 5, (n_fm, channels, filter_w, filter_h))).float()

I now want to extract patches from that input. First I did this manually:

def getPatches(z, f, pad=0, s=1):
    # Retrieve input dimensions; z is a single sample (channels, z_w, z_h), e.g. z[0]
    n_c, z_w, z_h = z.shape

    # Retrieve filter dimensions
    n_f, n_c, f_w, f_h = f.shape

    # Compute the dimensions of the CONV output volume
    # (pad only enters this size formula; z itself is not padded here)
    o_h = (z_h + 2 * pad - f_h) // s + 1
    o_w = (z_w + 2 * pad - f_w) // s + 1
    p = torch.zeros((o_w * o_h, n_c, f_w, f_h))

    for currW in range(o_w):
        for currH in range(o_h):
            # Corner coordinates of the current image patch
            x1 = currW * s
            x2 = currW * s + f_w
            y1 = currH * s
            y2 = currH * s + f_h

            # Store the patch at its flattened (row-major) position
            p[currW * o_h + currH] = z[:, x1:x2, y1:y2]

    return p

Later in my code I’m using different torch.einsum calls. Each of them gets some quantities as input that were calculated from the patches. Imagine something like torch.einsum("kcwhp,kp->pcwh", A, B), where A is based on p (it would be too much to share the whole code).

The above einsum call needs 0.048s to compute the result, based on the manually created patches.

Now I’ve noticed that there is already such a utility in the torch.nn.functional package:
p_fold = F.unfold(z, kernel_size=(filter_w, filter_h), stride=(1,1)).transpose(1, 2).reshape(-1, channels, filter_w, filter_h)

I’ve checked that both methods return the same patches. But now torch.einsum("kcwhp,kp->pcwh", A, B), with the quantity A computed from p_fold, needs 1.03s. This is a big difference, and I can’t explain it, as the inputs A and B and the result of the einsum call are the same.

Are you aware of any timing issues with F.unfold, or is there something general about PyTorch and torch.einsum I might not know?

Thanks for helping!

Could you check whether the einsum time is reduced in the second case if you manually add a .contiguous() call on the output of the unfold chain? That output might not be contiguous, and I’m not sure whether einsum would call .contiguous() internally, which could explain the difference in performance.
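A minimal sketch of that check, under several assumptions: the sizes are made up, the batch size is 1, everything runs on the CPU, and the einsum "pcwh,kp->kcwh" is only a hypothetical stand-in for the real computation, since the code that builds A was not shared. Integer-valued data is used so both layouts give exactly the same result:

```python
import time

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not the values used in the post)
channels, z_w, z_h = 8, 32, 32
filter_w, filter_h = 3, 3
k = 16  # hypothetical size of the "k" index

torch.manual_seed(0)
z = torch.randint(1, 10, (1, channels, z_w, z_h)).float()

p_fold = (
    F.unfold(z, kernel_size=(filter_w, filter_h), stride=(1, 1))
    .transpose(1, 2)
    .reshape(-1, channels, filter_w, filter_h)
)
p_contig = p_fold.contiguous()  # same values, densely packed memory

print("p_fold contiguous:  ", p_fold.is_contiguous())
print("p_contig contiguous:", p_contig.is_contiguous())

# Hypothetical second operand with one row per "k" and one column per patch
B = torch.randint(1, 5, (k, p_fold.shape[0])).float()


def time_einsum(patches, repeats=20):
    """Return the einsum result and the mean wall-clock time per call."""
    torch.einsum("pcwh,kp->kcwh", patches, B)  # warm-up call, not timed
    start = time.perf_counter()
    for _ in range(repeats):
        out = torch.einsum("pcwh,kp->kcwh", patches, B)
    return out, (time.perf_counter() - start) / repeats


out_view, t_view = time_einsum(p_fold)
out_contig, t_contig = time_einsum(p_contig)
print(f"as returned by reshape: {t_view:.6f}s per call")
print(f"after .contiguous():    {t_contig:.6f}s per call")
```

Whether the two timings differ depends on the shapes and the hardware, so this is only meant as a template for the measurement. On a GPU, a torch.cuda.synchronize() call before reading the clock would be needed to get meaningful numbers.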