Hi all,
I’m trying to port some JAX code to PyTorch, so I’m naturally turning to torch.vmap to avoid rewriting a lot of it in a different style. However, I’ve already run into an issue. I’m building hidden Markov models (or state-space models), and one thing I do a lot is sample from transition kernels. These are functions that apply a deterministic map plus some noise to a given input. The randomness part is where I’m stuck, because I’m confused about how torch.distributions works. I wrote a Noise class as follows:
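from typing import Callable
import torch.distributions as dists

class Noise:
    def __init__(self,
                 noise_dist: dists.Distribution,
                 noise_fn: Callable):
        self.noise_dist = noise_dist  # distribution class, e.g. dists.MultivariateNormal
        self.noise_fn = noise_fn      # combines the input x with the sampled noise

    def __call__(self, x, params):
        # instantiate the distribution from params['dist'] and draw one reparameterized sample
        random_sample = self.noise_dist(**params['dist']).rsample()
        return self.noise_fn(x, random_sample, params['fn'])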
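For reference, MultivariateNormal.rsample() produces a reparameterized draw. It samples standard-normal noise eps and returns loc + L @ eps, where L is the Cholesky factor of covariance_matrix. The sketch below checks this by reseeding the global RNG. The covariance matrix and d = 3 are arbitrary example values. The exact match also assumes that rsample() and torch.randn consume the CPU generator in the same way, which holds in current releases but is an implementation detail.

```python
import torch
import torch.distributions as dists

d = 3  # example dimension, chosen only for illustration
loc = torch.zeros(d)
cov = torch.tensor([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.3],
                    [0.0, 0.3, 1.5]])  # arbitrary positive-definite example
mvn = dists.MultivariateNormal(loc=loc, covariance_matrix=cov)

# Draw with rsample() under a fixed seed...
torch.manual_seed(0)
via_rsample = mvn.rsample()

# ...and rebuild the same draw by hand: standard-normal eps through the Cholesky factor.
torch.manual_seed(0)
eps = torch.randn(d)
by_hand = loc + torch.linalg.cholesky(cov) @ eps

print(torch.allclose(via_rsample, by_hand, atol=1e-6))
```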
Now suppose I want to apply simple additive Gaussian noise to a batch of samples of dimension d. I wrote:
import torch

AdditiveGaussianNoise = Noise(noise_dist=dists.MultivariateNormal,
                              noise_fn=lambda x, y, _: x + y)
noise_params = {'dist': {'loc': torch.zeros(d),
                         'covariance_matrix': torch.eye(d)},
                'fn': {}}
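torch.distributions objects can already handle batching without vmap. rsample() accepts a sample_shape argument, so a whole batch of noise can be drawn in a single call. Below is a minimal sketch of that alternative. The values d = 3 and num_samples = 5 are for illustration only. The additive noise_fn is redefined inline so the snippet runs on its own, and the approach assumes noise_fn only needs elementwise broadcasting.

```python
import torch
import torch.distributions as dists

d, num_samples = 3, 5  # illustrative sizes only
noise_params = {'dist': {'loc': torch.zeros(d),
                         'covariance_matrix': torch.eye(d)},
                'fn': {}}
noise_fn = lambda x, y, _: x + y  # same additive kernel as in the post

clean_samples = torch.ones(num_samples, d)
dist = dists.MultivariateNormal(**noise_params['dist'])
# sample_shape=(num_samples,) yields one independent d-dimensional draw per row
noise = dist.rsample((num_samples,))
noisy_samples = noise_fn(clean_samples, noise, noise_params['fn'])
print(noisy_samples.shape)  # torch.Size([5, 3])
```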
If I then try to apply that noise to a batch of samples:
clean_samples = torch.ones(num_samples, d)
noisy_samples = torch.vmap(AdditiveGaussianNoise, in_dims=(0, None),
                           randomness='different')(clean_samples, noise_params)
I get the following error “RuntimeError: vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. If this is necessary for your usage, please file an issue with functorch.”
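A stripped-down experiment suggests the message concerns an in-place random op rather than the input itself. Inside rsample(), torch.distributions currently draws its standard-normal noise with roughly torch.empty(shape).normal_(), which is an internal detail that may change between versions. Since loc and covariance_matrix are passed with in_dims=None, that freshly created tensor is not batched. The sketch below assumes PyTorch 2.x, where torch.vmap is available. It compares that in-place pattern with the out-of-place factory torch.randn under randomness='different'.

```python
import torch

x = torch.ones(4, 3)  # a batch of 4 samples of dimension 3

def inplace_noise(row):
    # fresh, unbatched tensor filled in place: the pattern rsample() uses internally
    return row + torch.empty(3).normal_()

def factory_noise(row):
    # out-of-place factory: vmap can give each batch element its own draw
    return row + torch.randn(3)

try:
    torch.vmap(inplace_noise, randomness='different')(x)
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print(err)  # expected to be the same "different inplace randomness" error

out = torch.vmap(factory_noise, randomness='different')(x)
print(out.shape)  # torch.Size([4, 3])
```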
My questions are the following:
- Why is clean_samples considered “unbatched”? Its first dimension is effectively a batch dimension, and the second is the dimension of the sample space. Is there a particular reshaping I should be doing?
- How could I make it work?
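One possible workaround that keeps torch.vmap assumes the noise is Gaussian. Instead of calling rsample() inside the vmapped function, reparameterize by hand: draw eps with the out-of-place torch.randn and return x + loc + L @ eps, where L is the Cholesky factor of the covariance. This mirrors how MultivariateNormal.rsample() builds its samples. In the sketch below, additive_gaussian_kernel is a hypothetical name, and d and num_samples are illustrative values.

```python
import torch

def additive_gaussian_kernel(x, params):
    # hypothetical vmap-friendly replacement for Noise.__call__ (Gaussian case only)
    loc = params['dist']['loc']
    scale_tril = torch.linalg.cholesky(params['dist']['covariance_matrix'])
    # out-of-place draw: vmap gives each batch element its own eps under randomness='different'
    eps = torch.randn(loc.shape, dtype=loc.dtype, device=loc.device)
    return x + loc + scale_tril @ eps

d, num_samples = 3, 5  # illustrative sizes only
noise_params = {'dist': {'loc': torch.zeros(d),
                         'covariance_matrix': torch.eye(d)},
                'fn': {}}
clean_samples = torch.ones(num_samples, d)

noisy_samples = torch.vmap(additive_gaussian_kernel, in_dims=(0, None),
                           randomness='different')(clean_samples, noise_params)
print(noisy_samples.shape)  # torch.Size([5, 3])
```

The distribution parameters stay unbatched (in_dims=None), as in the original call. Only the way the noise is drawn changes.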
Thanks in advance!