Thank you for answering. I think you are talking about something different.
I have modules that return a torch.distributions.Distribution. I do not wish to know which distribution it is; I just know it is always the same type (e.g. always Normal, or always Categorical). I have a video stream, and each of the N images in the stream is fed into this module in a for loop, since the architecture is RNN-like. Afterwards I want to compute the log_prob for each of the N distributions I got. Usually, when I get tensors back instead of distributions, I would just stack the tensors and then do some processing. However, with distributions I do not see a way to stack them along the batch dimension, so I have to loop over them again.
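For what it's worth, the usual workaround is to stack the distributions' *parameters* rather than the Distribution objects themselves, and then rebuild a single distribution with batch_shape (N, ...) so log_prob is one call. A minimal sketch for the Normal case, where `step` is a hypothetical stand-in for the RNN-like module (for Categorical you would stack `logits` instead of `loc`/`scale`):

```python
import torch
from torch.distributions import Normal

# Hypothetical stand-in for the RNN-like module: each call consumes one
# frame plus a hidden state and returns a Normal over the same shape.
def step(frame, hidden):
    hidden = torch.tanh(frame + hidden)
    return Normal(loc=hidden, scale=torch.ones_like(hidden)), hidden

N, D = 5, 3
frames = torch.randn(N, D)
hidden = torch.zeros(D)

# Instead of keeping N Distribution objects, collect their parameters.
locs, scales = [], []
for t in range(N):
    dist, hidden = step(frames[t], hidden)
    locs.append(dist.loc)
    scales.append(dist.scale)

# Stack the parameters along a new leading dimension and rebuild one
# distribution whose batch_shape is (N, D); log_prob is a single call.
batched = Normal(loc=torch.stack(locs), scale=torch.stack(scales))
targets = torch.randn(N, D)
log_probs = batched.log_prob(targets)  # shape (N, D), no second loop
```

This works because torch.distributions broadcasts over the batch dimensions of its parameters, so the stacked Normal behaves exactly like the N per-step Normals evaluated element-wise.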