Matrix transform for TransformedDistribution?


I’m trying to use TransformedDistribution for cases such as applying the reparameterization trick to backpropagate through samples from a multivariate normal, but I've run into a problem. There doesn’t seem to be any Transform in torch.distributions that implements a plain matrix transformation of the form y = Ax + b, where A is a matrix. I thought AffineTransform would work, but it appears to be only a pointwise mapping (y = loc + scale * x), so it can't mix the components of x. Is there really no native implementation for this? I would have expected this to be a primary use case for TransformedDistribution. Do I just need to implement my own matrix transform?
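For reference, here is a minimal sketch of what such a custom transform might look like. It is not a built-in API: `MatrixAffineTransform` is a hypothetical name. The sketch assumes PyTorch 1.8 or later (for `domain`/`codomain` constraints and `torch.linalg`) and a square, invertible A. The key points are declaring a vector-valued event (`constraints.independent(constraints.real, 1)`) and returning log|det A| once per batch element from `log_abs_det_jacobian`.

```python
import torch
from torch.distributions import (
    Independent,
    MultivariateNormal,
    Normal,
    TransformedDistribution,
    constraints,
)
from torch.distributions.transforms import Transform


class MatrixAffineTransform(Transform):
    """Hypothetical transform y = A x + b acting on the last (event) dimension."""

    # Input and output are real vectors, i.e. one event dimension each.
    domain = constraints.independent(constraints.real, 1)
    codomain = constraints.independent(constraints.real, 1)
    bijective = True  # only holds if A is invertible (assumed here)

    def __init__(self, A, b, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.A = A  # shape (n, n), assumed square and invertible
        self.b = b  # shape (n,)

    def _call(self, x):
        # Batched matrix-vector product: (..., n) -> (..., n)
        return (self.A @ x.unsqueeze(-1)).squeeze(-1) + self.b

    def _inverse(self, y):
        # Solve A x = y - b instead of forming A^{-1} explicitly
        rhs = (y - self.b).unsqueeze(-1)
        return torch.linalg.solve(self.A, rhs).squeeze(-1)

    def log_abs_det_jacobian(self, x, y):
        # The Jacobian of y = A x + b is A everywhere, so the correction is
        # log|det A|, one value per batch element (event dim already reduced).
        logabsdet = torch.linalg.slogdet(self.A)[1]
        return logabsdet.expand(x.shape[:-1])


if __name__ == "__main__":
    n = 3
    A = torch.tensor([[2.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-1.0, 0.3, 1.5]], requires_grad=True)
    b = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)

    # Standard normal base with a vector event of size n
    base = Independent(Normal(torch.zeros(n), torch.ones(n)), 1)
    dist = TransformedDistribution(base, [MatrixAffineTransform(A, b)])

    # Reparameterized samples: gradients flow back to A and b
    z = dist.rsample((100,))
    z.pow(2).sum().backward()
    print(A.grad is not None, b.grad is not None)

    # Sanity check: if x ~ N(0, I), then A x + b ~ N(b, A A^T)
    with torch.no_grad():
        reference = MultivariateNormal(b, covariance_matrix=A @ A.T)
        y = dist.sample((5,))
        print(torch.allclose(dist.log_prob(y), reference.log_prob(y), atol=1e-4))
```

The sanity check also points at a design choice: for the purely Gaussian case, `MultivariateNormal(loc=b, scale_tril=L)` already supports `rsample`, so a custom transform is mainly useful when the base distribution is not Gaussian or when the matrix map is one step in a longer chain of transforms.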