Multiple Categorical distributions

I’m working with a feature vector made of several categorical features with different cardinalities. How can I write my code so that I can compute log_prob for the whole vector at once?

For now it looks like this:

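import torch
import torch.nn as nn
import torch.distributions as D

class Model(nn.Module):  # illustrative name; only the two methods were shown
    def __init__(self, k):
        super().__init__()
        # nn.Parameter registers the logits so optimizers and .parameters() see them
        self.x1 = nn.Parameter(torch.randn(k, 2))
        self.x2 = nn.Parameter(torch.randn(k, 5))
        self.x3 = nn.Parameter(torch.randn(k, 3))

    def forward(self, x):
        # x: (N, 3) integer codes; each Categorical has batch_shape (k,)
        D_x1 = D.Categorical(logits=self.x1)
        D_x2 = D.Categorical(logits=self.x2)
        D_x3 = D.Categorical(logits=self.x3)
        # unsqueeze(-1) broadcasts the N samples against the k rows, giving (N, k)
        return (D_x1.log_prob(x[:, 0].unsqueeze(-1))
                + D_x2.log_prob(x[:, 1].unsqueeze(-1))
                + D_x3.log_prob(x[:, 2].unsqueeze(-1)))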
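One way to score the whole vector with a single log_prob call is sketched below. This is a common padding workaround, not a dedicated PyTorch API. Every feature's logits are right-padded with -inf up to the largest cardinality, so the padded categories get probability zero, and then stacked into one Categorical whose batch_shape covers all features. Wrapping it in D.Independent(..., 1) sums the per-feature log-probs. The class name PaddedCategoricals, the default cardinalities (2, 5, 3), and the reading of k as a number of components are assumptions based on the snippet above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D


class PaddedCategoricals(nn.Module):
    """Several categorical features with different cardinalities, scored in one call."""

    def __init__(self, k, cardinalities=(2, 5, 3)):
        super().__init__()
        self.max_card = max(cardinalities)
        # One (k, c_i) logit parameter per feature, like x1, x2, x3 in the question.
        self.feature_logits = nn.ParameterList(
            [nn.Parameter(torch.randn(k, c)) for c in cardinalities]
        )

    def padded_logits(self):
        # Right-pad each (k, c_i) tensor with -inf to (k, max_card) so the padded
        # categories get probability 0, then stack -> (k, n_features, max_card).
        return torch.stack(
            [F.pad(lg, (0, self.max_card - lg.shape[-1]), value=float("-inf"))
             for lg in self.feature_logits],
            dim=1,
        )

    def forward(self, x):
        # Categorical batch_shape is (k, n_features); Independent(..., 1) moves the
        # feature dim into the event, so log_prob sums over the features.
        dist = D.Independent(D.Categorical(logits=self.padded_logits()), 1)
        # x: (N, n_features) integer codes -> (N, 1, n_features) to broadcast over k.
        return dist.log_prob(x.unsqueeze(1))  # shape (N, k)
```

One caveat: the Categorical support is checked against the largest cardinality, so an out-of-range code for a smaller feature (say 3 for the two-category feature) yields -inf instead of raising an error. Validate inputs separately if that matters.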

I’d use D.MixtureSameFamily, but I can’t figure out how to handle a different number of logits per feature.
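The same -inf padding trick can feed D.MixtureSameFamily. The padded, Independent-wrapped Categorical becomes the component distribution with batch_shape (k,) and event_shape (3,), so the differing cardinalities never reach the mixture itself. This is a minimal sketch, assuming k is the number of mixture components; mix_logits, feature_logits and make_mixture are hypothetical names.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
k = 4                       # number of mixture components (assumed meaning of k)
cardinalities = [2, 5, 3]   # the three features from the question
max_card = max(cardinalities)

# Hypothetical leaf parameters: mixture weights and one (k, c_i) logit tensor per feature.
mix_logits = torch.randn(k, requires_grad=True)
feature_logits = [torch.randn(k, c, requires_grad=True) for c in cardinalities]


def make_mixture():
    # Pad every feature to max_card with -inf: padded categories get probability 0.
    padded = torch.stack(
        [F.pad(lg, (0, max_card - lg.shape[-1]), value=float("-inf")) for lg in feature_logits],
        dim=1,
    )  # (k, n_features, max_card)
    # Component distribution: batch_shape (k,), event_shape (n_features,).
    components = D.Independent(D.Categorical(logits=padded), 1)
    return D.MixtureSameFamily(D.Categorical(logits=mix_logits), components)


mix = make_mixture()
# N = 6 random observations, one integer code per feature -> shape (6, 3)
x = torch.stack([torch.randint(0, c, (6,)) for c in cardinalities], dim=1)
lp = mix.log_prob(x)  # shape (6,): log-likelihood of each whole feature vector
lp.sum().backward()   # gradients reach mix_logits and every feature_logits tensor
```

make_mixture() should be called again after every parameter update (for example inside forward), because Categorical normalizes its logits when it is constructed.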

Thanks!