PyTorch equivalent of TF MultivariateNormalDiag?

What is the PyTorch equivalent of TensorFlow’s MultivariateNormalDiag distribution? Specifically, I have a B x N x D mean tensor and a B x N x D variance tensor, where B is the batch size, N is the number of data points, and D is the dimension of each data point. I want to create a multivariate normal distribution with diagonal covariance from these tensors. How can this be implemented?


Hi,

Please see this pull request: https://github.com/pytorch/pytorch/pull/11178
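A minimal sketch of one way to get a diagonal-covariance multivariate normal in PyTorch, assuming the standard `torch.distributions` API: wrap an elementwise `Normal` in `Independent` so that the last dimension (D) is treated as the event dimension. The sizes below are hypothetical. Note that `Normal` takes the standard deviation as `scale`, so the variance tensor has to be square-rooted first.

```python
import torch
from torch.distributions import Independent, Normal

B, N, D = 4, 10, 3  # hypothetical sizes for illustration

mean = torch.randn(B, N, D)
variance = torch.rand(B, N, D) + 0.1  # variances must be strictly positive

# Normal is parameterized by the standard deviation, not the variance.
base = Normal(loc=mean, scale=variance.sqrt())

# Reinterpret the last batch dimension (D) as the event dimension, so each
# of the B x N entries is one D-dimensional Gaussian with diagonal covariance.
mvn_diag = Independent(base, reinterpreted_batch_ndims=1)

print(mvn_diag.batch_shape)  # torch.Size([4, 10])
print(mvn_diag.event_shape)  # torch.Size([3])

x = mvn_diag.sample()         # shape (B, N, D)
log_p = mvn_diag.log_prob(x)  # shape (B, N): one log-density per data point
```

Because the event dimension is D, `log_prob` sums over the D coordinates and returns one value per data point, which is the behavior one would expect from a diagonal multivariate normal.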

Bests

Does that mean that

import tensorflow_probability as tfp
import torch
D = torch.distributions

tfp.distributions.MultivariateNormalDiag(...) == D.Independent(D.Normal(...), 1)

Is it also equivalent from a computational efficiency perspective?
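One way to check the equivalence numerically, assuming the standard `torch.distributions` API (sizes and seed are hypothetical): build the same diagonal Gaussian twice, once as `Independent(Normal(...), 1)` and once as a full `MultivariateNormal` whose Cholesky factor is the diagonal matrix of standard deviations, then compare log-densities. This sketch only covers the PyTorch side; it does not test TensorFlow.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
B, N, D = 4, 10, 3  # hypothetical sizes for illustration

mean = torch.randn(B, N, D)
std = torch.rand(B, N, D) + 0.1  # standard deviations, strictly positive

# Diagonal Gaussian: stores only B x N x D standard deviations.
diag = Independent(Normal(mean, std), 1)

# Same distribution expressed with a full covariance: diag_embed builds a
# B x N x D x D lower-triangular (here diagonal) Cholesky factor.
full = MultivariateNormal(loc=mean, scale_tril=torch.diag_embed(std))

x = torch.randn(B, N, D)
lp_diag = diag.log_prob(x)  # shape (B, N)
lp_full = full.log_prob(x)  # shape (B, N)

print(torch.allclose(lp_diag, lp_full, atol=1e-5))
```

The two should agree up to floating-point error. On efficiency, the `Independent(Normal)` form only stores and processes D numbers per data point, whereas the `MultivariateNormal` form materializes a D x D factor and evaluates `log_prob` through a triangular solve, so its memory and compute grow roughly with D² rather than D. For a purely diagonal covariance, the `Independent(Normal)` form is generally the cheaper choice.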