Hi!
I'm constructing an image dataset using `torch.einsum`, and I ran into a behavior that I could not understand. The following dummy code reproduces what puzzles me:
```python
import torch

rho = torch.randn(10, 4, 4)
basis = torch.randn(4, 2, 64, 64)
result = torch.einsum('imn,mjkl,njkl->ijkl', rho, basis, basis)
```
The variable `result`, which has shape `(10, 2, 64, 64)`, would be my dataset: 10 images, each with 2 color channels, at a resolution of 64x64.
When I started training neural networks with this dataset, I found that the process was around 3x slower than with other datasets of the same shape and datatype. Investigating further, I found that `result.stride()` is `(1, 40960, 640, 10)`, which explains the slowness: during batching, each batch is not stored contiguously in memory. After correcting the dataset with `result.contiguous()`, I get back to normal speeds.
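To make the observation reproducible, here is a minimal sketch of the check I ran; the exact stride values einsum produces may vary by PyTorch version, so the printed stride is only what I observed:

```python
import torch

rho = torch.randn(10, 4, 4)
basis = torch.randn(4, 2, 64, 64)
result = torch.einsum('imn,mjkl,njkl->ijkl', rho, basis, basis)

# The strides are not in descending order, so indexing along the first
# (batch) dimension does not read a contiguous block of memory.
print(result.stride())         # observed: (1, 40960, 640, 10)
print(result.is_contiguous())

# .contiguous() copies the data into standard row-major layout,
# without changing the values.
fixed = result.contiguous()
print(fixed.stride())          # (8192, 4096, 64, 1)
print(fixed.is_contiguous())   # True
```

After this copy, batches sliced along the first dimension are contiguous in memory, which restores normal training speed.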
I would therefore like to understand what determines the stride of the result of `einsum`. Is there a better way to obtain the stride that I want?
Thank you very much!