Parallel sampling with torch.vmap, confusion on the notion of "batched" distributions

Hi all,

I’m trying to port some JAX code to PyTorch, so I’m naturally turning to torch.vmap to avoid rewriting everything in a different style. However, I have already hit a problem. I’m building hidden Markov models (or state-space models), and one thing I do a lot is sample from transition kernels: functions that apply a deterministic map plus some noise to a given input. The randomness part is giving me trouble because I’m confused about how torch distributions work. I wrote a Noise class as follows:

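from typing import Callable

import torch.distributions as dists


class Noise:
    def __init__(self,
                 noise_dist: dists.Distribution,  # distribution class, e.g. dists.MultivariateNormal
                 noise_fn: Callable):             # combines the input with a noise sample
        self.noise_dist = noise_dist
        self.noise_fn = noise_fn

    def __call__(self, x, params):
        # Instantiate the distribution from its parameters, then draw a reparameterized sample
        random_sample = self.noise_dist(**params['dist']).rsample()
        return self.noise_fn(x, random_sample, params['fn'])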

Now suppose I want to apply simple additive Gaussian noise to a batch of samples of dimension d. I wrote:

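import torch

d = 3  # example dimension; any positive integer works

# Additive Gaussian noise: x + eps, with eps ~ N(loc, covariance_matrix)
AdditiveGaussianNoise = Noise(noise_dist=dists.MultivariateNormal,
                              noise_fn=lambda x, y, _: x + y)

noise_params = {'dist': {'loc': torch.zeros(d),
                         'covariance_matrix': torch.eye(d)},
                'fn': {}}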

To apply that noise to a batch of samples, I do:

num_samples = 10  # example batch size
clean_samples = torch.ones(num_samples, d)
noisy_samples = torch.vmap(AdditiveGaussianNoise, in_dims=(0, None), randomness='different')(clean_samples, noise_params)

I get the following error: “RuntimeError: vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. If this is necessary for your usage, please file an issue with functorch.”
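For comparison, here is a minimal sketch of the result I am after, written without vmap. It relies on the sample_shape argument of rsample to draw one independent noise vector per row. The sizes d = 3 and num_samples = 5 are example values chosen only for illustration.

```python
import torch
import torch.distributions as dists

d, num_samples = 3, 5  # example sizes, chosen only for illustration

clean_samples = torch.ones(num_samples, d)
noise_dist = dists.MultivariateNormal(loc=torch.zeros(d),
                                      covariance_matrix=torch.eye(d))

# sample_shape=(num_samples,) draws one independent d-dimensional vector per row
noise = noise_dist.rsample((num_samples,))
noisy_samples = clean_samples + noise

print(noisy_samples.shape)  # torch.Size([5, 3])
```

This gives the output I expect, but it means restructuring the code around explicit batch shapes, which is what I was hoping vmap would let me avoid.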

My questions are the following:

  1. Why is clean_samples considered “unbatched”? The first dimension is effectively a batching dimension, and the second is the dimension of the sampling space. Is there a particular reshaping I should be doing?
  2. How could I make it work?
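To frame the second question, here is a minimal sketch of a variant that may sidestep the error, under the assumption that the in-place sampling happens inside rsample. It draws standard normal noise with torch.randn inside the vmapped function and scales it with a Cholesky factor. The names add_gaussian_noise, loc and scale_tril are placeholders of mine, and this bypasses torch.distributions entirely rather than fixing the Noise class.

```python
import torch

d, num_samples = 3, 5  # example sizes, chosen only for illustration

loc = torch.zeros(d)
# Cholesky factor L of the covariance, so that L @ eps has covariance L @ L.T
scale_tril = torch.linalg.cholesky(torch.eye(d))


def add_gaussian_noise(x):
    # Hypothetical helper: torch.randn is a random factory call (not an in-place op),
    # which vmap can draw separately per batch element when randomness='different'
    eps = torch.randn(d)
    return x + loc + scale_tril @ eps


clean_samples = torch.ones(num_samples, d)
noisy_samples = torch.vmap(add_gaussian_noise, randomness='different')(clean_samples)

print(noisy_samples.shape)  # torch.Size([5, 3])
```

If something like this is the recommended route, I would still like to know whether the torch.distributions objects themselves can be used under vmap.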

Thanks in advance!