Reparametrization Trick (Normal distributions)

How does torch.normal work under the hood?

I am asking this with regard to the reparametrization trick.

import torch
import numpy as np

N = 1
mu_grads = []
std_grads = []


def reparametrize(mu, std):
    eps = torch.randn_like(std)  # standard normal noise, eps ~ N(0, 1)
    return mu + std * eps

mu = torch.tensor([12.], requires_grad=True) # set mu = 12 and store gradient
std = torch.tensor([42.], requires_grad=True) # set std = 42

for i in range(N):
    mu.grad = None # reset the gradient value
    std.grad = None

    z = reparametrize(mu, std)
    z2 = z ** 2
    z2.backward()

    mu_grads.append(mu.grad.detach().cpu().numpy())
    std_grads.append(std.grad.detach().cpu().numpy())

print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")

The above code works fine!
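
As a sanity check (assuming eps ~ N(0, 1)): E[z^2] = mu^2 + std^2, so the true gradients are dE[z^2]/dmu = 2*mu = 24 and dE[z^2]/dstd = 2*std = 84. With N = 1 the estimate is a single noisy sample; with a larger N it should get close to those values. A minimal sketch:

import torch
import numpy as np

mu = torch.tensor([12.], requires_grad=True)
std = torch.tensor([42.], requires_grad=True)

N = 10_000
mu_grads, std_grads = [], []
for _ in range(N):
    mu.grad = None
    std.grad = None
    z = mu + std * torch.randn_like(std)  # reparametrized sample
    (z ** 2).backward()
    mu_grads.append(mu.grad.item())
    std_grads.append(std.grad.item())

print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")   # should approach 2 * mu = 24
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")  # should approach 2 * std = 84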
But the code below does not:

import torch
import numpy as np

N = 1
mu_grads = []
std_grads = []

mu = torch.tensor([12.], requires_grad=True) # set mu = 12 and store gradient
std = torch.tensor([42.], requires_grad=True) # set std = 42

for i in range(N):
    mu.grad = None # reset the gradient value
    std.grad = None

    z = torch.normal(mu, std) # the random sampling happens here
    # but shouldn't the same thing happen here as in the reparametrize function?
    # mu + std * torch.randn(1)  <- Is this how torch.normal works?
    z2 = z ** 2
    z2.backward()

    mu_grads.append(mu.grad.detach().cpu().numpy())
    std_grads.append(std.grad.detach().cpu().numpy())

print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")

mu + std * torch.randn(1) ← Is this how torch.normal works?

Please help!
Thank you!


Can anyone please help?
@ptrblck

If I'm not mistaken, torch.normal will call into this normal function, which uses the mean and std inputs as plain scalar types and thus detaches them from the tensors, which is why your first approach works.
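
One way to see this (a minimal check; the exact behaviour of torch.normal may differ across PyTorch versions) is to inspect whether the sampled tensor is still attached to the autograd graph:

import torch

mu = torch.tensor([12.], requires_grad=True)
std = torch.tensor([42.], requires_grad=True)

z_manual = mu + std * torch.randn_like(std)  # manual reparametrization
z_normal = torch.normal(mu, std)             # direct sampling

# The manual sample keeps a grad_fn, so gradients can flow back to mu and std.
print(z_manual.requires_grad, z_manual.grad_fn)
# Depending on the version, torch.normal may be detached from the graph
# or may only provide zero gradients for mu and std.
print(z_normal.requires_grad, z_normal.grad_fn)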

Can you please elaborate on why it detaches them, and what you mean by plain scalar types?

Thanks!

I do not know if this answers your question, but I think you should use rsample from the Normal distribution in PyTorch for the reparametrization trick.

Thus:

normal = torch.distributions.normal.Normal(loc=mu, scale=std)

sample = normal.rsample()
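
As a minimal sketch of how this plugs into the script above (Normal.rsample draws eps ~ N(0, 1) and returns loc + scale * eps, so the sample stays differentiable with respect to mu and std):

import torch

mu = torch.tensor([12.], requires_grad=True)
std = torch.tensor([42.], requires_grad=True)

normal = torch.distributions.Normal(loc=mu, scale=std)
z = normal.rsample()  # reparametrized sample: mu + std * eps
(z ** 2).backward()

print(mu.grad)   # gradient flows through the sample (2 * z)
print(std.grad)  # gradient flows through the sample (2 * z * eps)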