LowRankMultivariateNormal throws CUDA error

Hi there!

I’m using the LowRankMultivariateNormal distribution to get a distribution of logits for every pixel of an image.
I run into an issue when using this distribution for square images whose side length is even and large (>= 512x512).

The following code fails in a Colab notebook when the size is set to 512x512, but works for 513x513.
I also tried it on different GPUs, with the same result.
It works fine on the CPU.

import torch
from torch.distributions import LowRankMultivariateNormal

DEVICE = "cuda"

torch.manual_seed(23)
for i in range(10):
    print(i)
    distrib = LowRankMultivariateNormal(
        torch.randn(1, 512, 512, 2).to(DEVICE),      # loc: a 2-dim event per pixel
        torch.randn(1, 512, 512, 2, 10).to(DEVICE),  # cov_factor: rank-10 factor per pixel
        torch.randn(1, 512, 512, 2).to(DEVICE).exp() # cov_diag: positive diagonal per pixel
    )

(I added the for-loop because it sometimes works for the first iteration, although that was not the case the last few times I ran it.)

I don’t have any clue how to explain this behavior.

[edit]
The issue occurs for shapes that are a power of 2, not for all even shapes.
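
For what it’s worth, since the CUDA context is unusable after the illegal memory access, each size has to be checked in a fresh process. A rough sweep sketch (the sizes and helper script are just illustrative, not exactly what I ran):

import subprocess
import sys

# Each size runs in its own interpreter: after an illegal memory access
# the CUDA context is corrupted, so later calls in the same process
# would fail regardless of the size being tested.
SNIPPET = """
import torch
from torch.distributions import LowRankMultivariateNormal
torch.manual_seed(23)
size = {size}
LowRankMultivariateNormal(
    torch.randn(1, size, size, 2, device="cuda"),
    torch.randn(1, size, size, 2, 10, device="cuda"),
    torch.randn(1, size, size, 2, device="cuda").exp(),
)
"""

for size in (255, 256, 257, 511, 512, 513, 1024):
    result = subprocess.run(
        [sys.executable, "-c", SNIPPET.format(size=size)],
        capture_output=True,
    )
    print(f"{size}x{size}: {'ok' if result.returncode == 0 else 'FAILED'}")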

Thanks for reporting the issue! 🙂
I could reproduce it with the latest nightly binary and have created this issue to track it.

It seems that the magma_spotrf_batched call causes the illegal memory access.
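
If my reading of the internals is correct, the constructor takes a batched Cholesky of the small rank×rank “capacitance” matrix for every pixel, so the failure might reduce to a single batched Cholesky with a power-of-two batch count. A sketch of a more minimal repro (the dispatch to MAGMA’s batched kernel is my assumption, not verified):

import torch

# 1 * 512 * 512 * 2 capacitance matrices of shape 10x10 (power-of-two batch count)
batch = 1 * 512 * 512 * 2
W = torch.randn(batch, 10, 10, device="cuda")
spd = W @ W.transpose(-1, -2) + 10 * torch.eye(10, device="cuda")  # make the matrices SPD
L = torch.cholesky(spd)  # float32 batched Cholesky; presumably hits magma_spotrf_batched
torch.cuda.synchronize()
print("ok", L.shape)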

Thanks for the response.
Do you think that there could be a workaround in the meantime?

I’m unsure at the moment, as we are using MAGMA for this particular operation, and that is where the failure occurs.
Let’s see if the code owners have an idea for a workaround.
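
One stopgap that might work in the meantime (a sketch, assuming you only need samples or log-probs and can afford the slower CPU path) is to construct the distribution on the CPU and move the outputs to the GPU afterwards:

import torch
from torch.distributions import LowRankMultivariateNormal

# Build the distribution on the CPU, where the Cholesky path works,
# and move only the results back to the GPU.
distrib = LowRankMultivariateNormal(
    torch.randn(1, 512, 512, 2),
    torch.randn(1, 512, 512, 2, 10),
    torch.randn(1, 512, 512, 2).exp(),
)
sample = distrib.rsample().to("cuda")  # shape (1, 512, 512, 2), now on the GPU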