Hi there!
I’m using the LowRankMultivariateNormal distribution to get a distribution over logits for every pixel of an image.
I run into a problem when using this distribution for square images whose side length is even and large (>= 512x512).
The following code fails in a Colab notebook when the size is set to 512x512, but works for 513x513.
I also tried it on different GPUs, with the same results.
It works fine when using CPUs.
import torch
from torch.distributions import LowRankMultivariateNormal

DEVICE = "cuda"
torch.manual_seed(23)

for i in range(10):
    print(i)
    distrib = LowRankMultivariateNormal(
        torch.randn(1, 512, 512, 2).to(DEVICE),        # loc
        torch.randn(1, 512, 512, 2, 10).to(DEVICE),    # cov_factor (rank 10)
        torch.randn(1, 512, 512, 2).to(DEVICE).exp(),  # cov_diag (positive)
    )
(I added the for-loop because sometimes it works for the first iteration, although it was not the case the last few times I ran it.)
I have no idea how to explain this behavior.
[edit]
The issue occurs for sizes that are a power of 2, not for even sizes in general.
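
To check the power-of-2 pattern more systematically, something like the sketch below could be used. It tries several sizes and records which ones raise an exception. The helper names (`is_power_of_two`, `probe_sizes`, `build_distribution`) and the list of sizes are made up for illustration and are not part of PyTorch. One caveat I'm assuming rather than have verified: if the failure leaves the CUDA context in a bad state, later sizes in the same process might fail as well, in which case it would be safer to run one size per process.

```python
def is_power_of_two(n: int) -> bool:
    """Return True if n is a positive power of two (1, 2, 4, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def probe_sizes(build, sizes):
    """Call build(size) for each size; map size -> None on success, else the error text."""
    results = {}
    for size in sizes:
        try:
            build(size)
            results[size] = None
        except Exception as exc:  # record the failure and keep going
            results[size] = f"{type(exc).__name__}: {exc}"
    return results


def build_distribution(size, device="cuda", rank=10):
    """Build the same distribution as in the post for a size x size image."""
    # Imported here so the helpers above can be used without torch installed.
    import torch
    from torch.distributions import LowRankMultivariateNormal

    distrib = LowRankMultivariateNormal(
        torch.randn(1, size, size, 2, device=device),        # loc
        torch.randn(1, size, size, 2, rank, device=device),  # cov_factor
        torch.randn(1, size, size, 2, device=device).exp(),  # cov_diag
    )
    if device == "cuda":
        # CUDA kernels run asynchronously; wait so errors surface inside this call.
        torch.cuda.synchronize()
    return distrib


if __name__ == "__main__":
    sizes = [255, 256, 257, 511, 512, 513]  # arbitrary sizes around powers of two
    for size, error in probe_sizes(build_distribution, sizes).items():
        status = "ok" if error is None else f"FAILED ({error})"
        print(f"{size:4d}  power_of_two={is_power_of_two(size)}  {status}")
```

Passing `device="cpu"` to `build_distribution` (for example through `functools.partial`) would give the CPU comparison for the same sizes.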