I’m using the LowRankMultivariateNormal distribution to model a distribution over logits for every pixel of an image.
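To make the setup concrete: in `LowRankMultivariateNormal(loc, cov_factor, cov_diag)`, the last dimension of `loc` is the event dimension and all leading dimensions are batch dimensions. A `loc` of shape (1, H, W, 2) therefore gives one 2-D Gaussian per pixel, each with its own low-rank-plus-diagonal covariance. The small CPU sketch below checks this on a 4x4 image, assuming rank 10 as in the reproduction code; `make_pixel_distribution` is a hypothetical helper, not part of PyTorch.

```python
import torch
from torch.distributions import LowRankMultivariateNormal


def make_pixel_distribution(height, width, rank=10, num_logits=2, device="cpu"):
    """Hypothetical helper: one low-rank Gaussian over `num_logits` logits per pixel."""
    loc = torch.randn(1, height, width, num_logits, device=device)
    cov_factor = torch.randn(1, height, width, num_logits, rank, device=device)
    # cov_diag must be strictly positive, hence exp()
    cov_diag = torch.randn(1, height, width, num_logits, device=device).exp()
    return LowRankMultivariateNormal(loc, cov_factor, cov_diag)


dist = make_pixel_distribution(4, 4)
print(dist.batch_shape)  # torch.Size([1, 4, 4]): one distribution per pixel
print(dist.event_shape)  # torch.Size([2]): two logits per pixel
sample = dist.rsample()
print(sample.shape)      # torch.Size([1, 4, 4, 2])
```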
I run into an issue when using this distribution on square images whose side length is even and large (>= 512x512).
The following code fails in a Colab notebook when the size is set to 512x512, but works for 513x513.
I also tried it on different GPUs, with the same results.
It works fine on the CPU.
```python
import torch
from torch.distributions import LowRankMultivariateNormal

DEVICE = "cuda"
torch.manual_seed(23)

for i in range(10):
    print(i)
    distrib = LowRankMultivariateNormal(
        torch.randn(1, 512, 512, 2).to(DEVICE),            # loc: 2 logits per pixel
        torch.randn(1, 512, 512, 2, 10).to(DEVICE),        # cov_factor: rank 10
        torch.randn(1, 512, 512, 2).to(DEVICE).exp()       # cov_diag: must be positive
    )
```
(I added the for-loop because it sometimes succeeds on the first iteration, although that was not the case the last few times I ran it.)
I have no idea how to explain this behavior.
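One way to narrow down where the CUDA failure happens is to run the batched linear-algebra step on its own. As far as I know, the constructor builds the "capacitance" matrix I + Wᵀ D⁻¹ W (rank × rank, so 10 × 10 here) for every pixel and Cholesky-factorizes the whole batch; at 512x512 that is 262,144 small matrices in one call. The sketch below re-implements that step by hand, assuming this description of the internals is accurate; `capacitance_cholesky` is a hypothetical helper, not part of `torch.distributions`.

```python
import torch


def capacitance_cholesky(cov_factor, cov_diag):
    """Hypothetical helper: Cholesky factor of I + W^T D^{-1} W per batch element.

    Assumed (not confirmed) to mirror what LowRankMultivariateNormal
    computes in its constructor. W: (..., n, m), D: (..., n).
    """
    rank = cov_factor.shape[-1]
    # W^T D^{-1}: transpose W, then divide column j by d_j
    wt_dinv = cov_factor.mT / cov_diag.unsqueeze(-2)
    capacitance = wt_dinv @ cov_factor  # shape (..., m, m)
    capacitance = capacitance + torch.eye(
        rank, dtype=cov_factor.dtype, device=cov_factor.device
    )
    return torch.linalg.cholesky(capacitance)


# Usage: swap in device="cuda" and size=512 to test the GPU path directly.
torch.manual_seed(23)
size, device = 8, "cpu"
W = torch.randn(1, size, size, 2, 10, device=device)
D = torch.randn(1, size, size, 2, device=device).exp()
L = capacitance_cholesky(W, D)
print(L.shape)  # torch.Size([1, 8, 8, 10, 10])
```

Since I + Wᵀ D⁻¹ W is always symmetric positive definite, this factorization should never fail mathematically; if it errors on the GPU only at 512x512, that would point at the batched Cholesky call rather than at the inputs.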
Edit: the issue occurs for sizes that are a power of 2, not for even sizes in general.
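To check the power-of-2 pattern systematically, a small sweep over sizes can help. The sketch below (with `try_size` as a hypothetical helper) builds the distribution at each size and returns the error message instead of crashing. CUDA errors are reported asynchronously, so it synchronizes before returning; note that a hard CUDA error can leave the context unusable, so running each size in a fresh process is safer when testing on the GPU.

```python
import torch
from torch.distributions import LowRankMultivariateNormal


def try_size(size, device="cpu", rank=10, seed=23):
    """Hypothetical helper: None if the distribution builds at size x size, else the error text."""
    torch.manual_seed(seed)
    try:
        LowRankMultivariateNormal(
            torch.randn(1, size, size, 2, device=device),
            torch.randn(1, size, size, 2, rank, device=device),
            torch.randn(1, size, size, 2, device=device).exp(),
        )
        if str(device).startswith("cuda"):
            torch.cuda.synchronize()  # surface asynchronous CUDA errors here
        return None
    except (RuntimeError, ValueError) as err:
        return str(err)


# On the CPU every size is expected to succeed; on the GPU, try e.g.
# (511, 512, 513, 1023, 1024, 1025) with device="cuda".
for size in (8, 9, 16, 17):
    print(size, try_size(size) or "ok")
```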