Implementing a Multivariate Normal Distribution in LibTorch

I’m trying to write a Multivariate Normal Distribution with rsample() in LibTorch because LibTorch has no distributions yet.
I’m not sure I’ve extracted what I need from the PyTorch source code. I think my problem is that my Multivariate Normal Distribution doesn’t know what to do with a covariance matrix in place of a single-number standard deviation.
I’m also having trouble recreating torch.distributions.MultivariateNormal in PyTorch itself, in Python.

The Python problems
This is my attempt to write a Multivariate Normal Distribution in PyTorch, in Python. I’m not getting errors, but the result is a matrix of size torch.Size([2, 2]). When I pass the exact same tensors to torch.distributions.MultivariateNormal and take a sample or an rsample, I get an output of size torch.Size([2]).

import torch
from torch.distributions import MultivariateNormal

µ = torch.tensor([-0.5, 2.0])
cov = torch.tensor([[2.0, 8], [2.0, 40]])
pos_cov = cov @ cov.T 
Σ = torch.linalg.cholesky(pos_cov)

ε = torch.randn(1)

torch.det(torch.pi * 2 * Σ)**1/2 * torch.exp(torch.tensor(-1/2) * (ε - µ).T * Σ**-1 * (ε - µ))

Outputs a tensor like

tensor([[1182.5215,    0.0000],
        [1245.9080, 1102.0341]])
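
A possible reason for the [2, 2] shape: in PyTorch, * and ** act elementwise with broadcasting, so multiplying a length-2 vector by a 2x2 matrix gives a 2x2 result, and a matrix raised to -1 with ** is an elementwise reciprocal rather than a matrix inverse. In addition, **1/2 parses as (x**1)/2. A small sketch with made-up values (v and M are placeholders chosen for illustration, not the tensors above):

```python
import torch

v = torch.tensor([1.0, 2.0])                  # placeholder for (ε - µ)
M = torch.tensor([[2.0, 0.0], [1.0, 3.0]])    # placeholder lower-triangular matrix

elementwise = v * M * v    # broadcasts v across M: result keeps M's shape
quadratic = v @ M @ v      # vector-matrix-vector product: a single number

print(elementwise.shape)   # torch.Size([2, 2])
print(quadratic.shape)     # torch.Size([])

# M ** -1 is the elementwise reciprocal; torch.linalg.inv is the matrix inverse
print(torch.allclose(M @ torch.linalg.inv(M), torch.eye(2)))

x = torch.tensor(16.0)
half_of_x = x ** 1 / 2     # parsed as (x ** 1) / 2, so this is 8
sqrt_of_x = x ** 0.5       # square root, so this is 4
print(half_of_x, sqrt_of_x)
```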

MultivariateNormal(µ, pos_cov).sample()
Outputs a tensor like
tensor([ 7.6214, 37.2902])
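
For comparison, the hand-written expression above resembles the multivariate normal density formula rather than a sampling rule. A reparameterized sample is commonly computed as µ + Lε, where L is the lower Cholesky factor of the covariance matrix and ε is standard normal noise with the same shape as µ. A minimal sketch, assuming that is what rsample() is meant to do; mvn_rsample is a hypothetical helper name, not a PyTorch API:

```python
import torch

def mvn_rsample(mean, cov):
    # Hypothetical helper (not part of PyTorch): draws mean + L @ eps.
    L = torch.linalg.cholesky(cov)     # lower-triangular, cov == L @ L.T
    eps = torch.randn_like(mean)       # one standard-normal value per dimension
    return mean + L @ eps              # same shape as mean

mean = torch.tensor([-0.5, 2.0])
a = torch.tensor([[2.0, 8.0], [2.0, 40.0]])
cov = a @ a.T                          # symmetric positive definite

sample = mvn_rsample(mean, cov)
print(sample.shape)                    # torch.Size([2])
```

Because the noise enters only through an addition and a matrix product, gradients can flow back to mean and cov, which is the point of rsample() as opposed to sample().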

The LibTorch problems
I don’t see anything about a covariance matrix or a Multivariate Normal Distribution in the CUDA DistributionTemplates.cu file.

I know my Python implementation of the Multivariate Normal Distribution doesn’t match torch.distributions.MultivariateNormal, but I tried porting what I have to LibTorch anyway.

#include <torch/torch.h>

class MultivariateNormalx{
    torch::Tensor mean, stddev, var;
public:
    MultivariateNormalx(const torch::Tensor &mean, const torch::Tensor &std) : mean(mean), stddev(std), var(std * std) {}

    torch::Tensor rsample() {
        auto device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
        auto eps = torch::randn(1).to(device);
        auto cholesky = torch::linalg::cholesky(torch::matmul(stddev.transpose(0, 1), stddev));
        auto pi = torch::tensor(3.141592653589793);
        return pow(torch::det(pi * 2 * cholesky), 1/2) * torch::exp(-1/2 * (eps - mean).transpose(0, 1) * pow(cholesky, -1) * (eps - mean));
    }
};

torch::Tensor μ = torch::tensor({{-0.5, 2.0}, {-0.5, 2.0}});
torch::Tensor cov = torch::tensor({{2.0, 8.0}, {2.0, 40.0}});

int main () {

    std::cout << MultivariateNormalx(μ, cov).rsample();

}

Always returns…

 1 -nan
 1  1
[ CPUFloatType{2,2} ]