What determines the stride of the output of einsum?


I’m constructing an image dataset using torch.einsum, and I found a behavior that I could not understand.

The following dummy code reproduces what puzzles me:

import torch

rho = torch.randn(10, 4, 4)
basis = torch.randn(4, 2, 64, 64)
result = torch.einsum('imn,mjkl,njkl->ijkl', rho, basis, basis)

The variable result, which has shape (10, 2, 64, 64), would be my dataset: 10 images, each with 2 color channels at 64x64 resolution.

When I started training neural networks with this dataset, I found that the process was around 3x slower than with other datasets of the same shape and datatype. Investigating further, I found that result.stride() = (1, 40960, 640, 10), which explains the slowness: during batching, the samples are not stored contiguously in memory. After correcting the dataset with result.contiguous(), I get back to normal speeds.
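For completeness, this is the fix I applied (the exact stride printed may depend on the PyTorch version, so I only check contiguity here):

```python
import torch

rho = torch.randn(10, 4, 4)
basis = torch.randn(4, 2, 64, 64)
result = torch.einsum('imn,mjkl,njkl->ijkl', rho, basis, basis)

# .contiguous() copies the data into a row-major buffer, so that each
# sample result[i] occupies one contiguous block of memory:
fixed = result.contiguous()
assert fixed.is_contiguous()
assert torch.equal(fixed, result)  # same values, possibly different layout
```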

I would like, therefore, to understand what determines the stride of the result of einsum. Is there a better way to obtain the stride that I want?

Thank you very much!

The strides of a computation's result are an implementation detail. What happens (in the most common case) is that PyTorch rearranges the dimensions to express the contraction of two operands as a batch matrix multiplication and then permutes the axes back to the desired output order.
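Assuming that is what happens in your case, you can reproduce the exact strides you observed by permuting a contiguous tensor whose dimensions are in the order the contraction was presumably computed in (here j, k, l, i):

```python
import torch

# A contiguous tensor in the hypothetical "compute" order (j, k, l, i)...
compute_order = torch.randn(2, 64, 64, 10)  # strides: (40960, 640, 10, 1)

# ...permuted back to the requested output order (i, j, k, l):
result_view = compute_order.permute(3, 0, 1, 2)  # shape (10, 2, 64, 64)
print(result_view.stride())  # (1, 40960, 640, 10), matching your einsum output
```

The permute is free (it only reinterprets the strides), which is why einsum returns the result in this layout rather than paying for a copy you might not need.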

One easy thing to try is to call .contiguous() right after the einsum.
If you are on the CPU, another thing you could try is to use numpy.einsum, which can be much more sophisticated than PyTorch’s implementation.

Best regards