Einsum Memory Efficiency

I have a situation where I want to do a complicated einsum operation across 4 tensors. I am trying to optimize for memory efficiency, and if the underlying code breaks this up into multiple matmul operations, it produces intermediate matrices that are far bigger than desired. My understanding is that other einsum implementations do use matmul as the underlying operation, but I have not found this info for PyTorch. Would this be efficient, and if not, is there a clean way to do multiply-accumulate operations across more than 2 matrices?
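For concreteness, here is a sketch of the kind of contraction in question. The shapes are made up for illustration, and I'm using NumPy so the snippet is self-contained (`torch.einsum` accepts the same subscript string):

```python
import numpy as np

# Hypothetical shapes: four tensors contracted in a single einsum call.
a = np.random.rand(32, 64)    # (i, j)
b = np.random.rand(64, 128)   # (j, k)
c = np.random.rand(128, 64)   # (k, l)
d = np.random.rand(64, 16)    # (l, m)

# One call, but an implementation that reduces this to pairwise matmuls,
# e.g. ((a @ b) @ c) @ d, materializes intermediates (here 32x128 and
# 32x64) that can be much larger than the 32x16 result.
out = np.einsum('ij,jk,kl,lm->im', a, b, c, d)
print(out.shape)  # (32, 16)
```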


Maybe you want to try https://github.com/dgasmith/opt_einsum to see if you can get a better contraction order?

For the exact implementation details, I’m afraid I don’t know :confused:


Thanks for linking this! If I am understanding correctly, this library deliberately builds intermediate tensors to reduce compute, which is good for some situations but not for the memory-constrained case. It does offer a bit more explanation than the PyTorch documentation, though.

Standard PyTorch einsum reduces to bmm calls in sequential order, so it's not memory efficient if you have large intermediates.
There isn't a built-in way, but based on experience with torch.bilinear, you might get away with computing the result "row-wise" instead of in one go.
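A minimal sketch of that row-wise idea, assuming the chunk size and shapes are free parameters (NumPy here for a self-contained example; the same loop works with `torch.einsum`):

```python
import numpy as np

# Hypothetical shapes for a four-tensor contraction.
a = np.random.rand(32, 64)
b = np.random.rand(64, 128)
c = np.random.rand(128, 64)
d = np.random.rand(64, 16)

chunk = 8  # rows of the output computed per step (tunable)
out = np.empty((a.shape[0], d.shape[1]))
for start in range(0, a.shape[0], chunk):
    # Each slice only materializes intermediates of height `chunk`
    # instead of the full first dimension.
    out[start:start + chunk] = np.einsum(
        'ij,jk,kl,lm->im', a[start:start + chunk], b, c, d)
```

Peak intermediate memory shrinks roughly in proportion to `chunk / a.shape[0]`, at the cost of launching more, smaller contractions.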

Best regards