What to do for gradients when rsample is not implemented in a distribution?


I’m currently looking at the MixtureSameFamily distribution class. I’m trying to build a Gaussian Mixture Model that I can sample from (using rsample), use the samples to calculate a loss, and then do a backward pass based on that loss. rsample, however, is currently not implemented for MixtureSameFamily.
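For concreteness, here is a minimal sketch of the setup I mean (the component count, dimensions, and parameter values are just placeholders): plain sampling works, but the distribution reports no reparameterized sampling, so gradients cannot flow through the samples.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

# Illustrative 3-component, 2-dimensional diagonal GMM.
mix = Categorical(logits=torch.zeros(3))
comp = Independent(Normal(torch.randn(3, 2), torch.ones(3, 2)), 1)
gmm = MixtureSameFamily(mix, comp)

print(gmm.has_rsample)   # False: no reparameterized sampling available
x = gmm.sample((5,))     # plain sampling works, but is not differentiable
print(x.shape)           # torch.Size([5, 2])
```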

I have no trouble getting it to work when the distribution is just a MultivariateNormal, since that class does have an rsample method.
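For comparison, this is the working single-Gaussian version (parameter values are again just placeholders); rsample lets the gradient of the loss reach the distribution parameters:

```python
import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(2, requires_grad=True)
mvn = MultivariateNormal(loc, torch.eye(2))

samples = mvn.rsample((10,))   # reparameterized: differentiable w.r.t. loc
loss = (samples ** 2).sum()
loss.backward()
print(loc.grad is not None)    # True: gradients flowed through the samples
```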

I understand that reparameterized sampling is problematic to implement for the mixture class because sampling the discrete component index is not differentiable (I saw the same limitation in TensorFlow’s mixture class). That being said, is there any other way, within torch or outside it, that I can collect samples from my GMM and define a loss on them that I can backpropagate through?
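One workaround I have been considering is a Gumbel-softmax style relaxation: rsample from every component and combine them with relaxed one-hot weights from RelaxedOneHotCategorical, so gradients reach both the mixture logits and the component parameters. This is only an approximation of a true GMM sample (it is a convex combination of component samples, controlled by the temperature), and all names and shapes below are illustrative:

```python
import torch
from torch.distributions import Normal, RelaxedOneHotCategorical

# Illustrative parameters for a 3-component, 2-dimensional diagonal GMM.
logits = torch.zeros(3, requires_grad=True)    # mixture weights
locs = torch.randn(3, 2, requires_grad=True)   # component means
scales = torch.ones(3, 2)                      # component stds (fixed here)

def relaxed_gmm_rsample(n, temperature=0.5):
    # Differentiable, approximately one-hot component weights: (n, 3).
    w = RelaxedOneHotCategorical(
        torch.tensor(temperature), logits=logits
    ).rsample((n,))
    # Reparameterized samples from every component: (n, 3, 2).
    comps = Normal(locs, scales).rsample((n,))
    # Convex combination over components -> approximate GMM sample: (n, 2).
    return (w.unsqueeze(-1) * comps).sum(dim=1)

x = relaxed_gmm_rsample(4)
x.sum().backward()  # gradients reach both logits and locs
```

As the temperature goes to zero the weights approach a hard one-hot selection, but the gradient variance grows, so the temperature is a trade-off to tune.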

Any suggestions would be highly appreciated. Thank you!