Softmax over different masks without a for loop

Say I have a vector a with an index vector b of the same length. The indices are in the range 0~9, corresponding to 10 classes (the toy example below uses 3 classes for brevity). How can I apply softmax within each class without a for loop?

Toy example:

a = torch.rand(10)
a: tensor([0.3376, 0.0557, 0.3016, 0.5550, 0.5814, 0.1306, 0.2697, 0.9989, 0.4917,
        0.6306])
b = torch.randint(0, 3, (10,), dtype=torch.int64)
b: tensor([1, 2, 0, 2, 2, 0, 1, 1, 1, 1])

I want to compute a per-class softmax, like

for index in range(3):
    torch.softmax(a[b == index], dim=0)

but without the for loop to save time.
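For reference, here is a sketch of one vectorized approach: compute a per-class maximum (for numerical stability) and a per-class sum of exponentials with scatter operations, then normalize each element by its own class's sum. This assumes a PyTorch version with `Tensor.scatter_reduce` (1.12+); the variable names `group_max`, `group_sum`, and `out` are just illustrative.

```python
import torch

torch.manual_seed(0)
a = torch.rand(10)                                   # values
b = torch.randint(0, 3, (10,), dtype=torch.int64)    # class index per element
num_classes = 3

# Per-class max, scattered back to each element (numerical stability).
group_max = torch.full((num_classes,), float("-inf")).scatter_reduce(
    0, b, a, reduce="amax", include_self=True
)
exp = torch.exp(a - group_max[b])

# Per-class sum of exponentials, then normalize element-wise.
group_sum = torch.zeros(num_classes).scatter_add(0, b, exp)
out = exp / group_sum[b]
```

Within each class, `out[b == index]` should sum to 1, matching the loop version above while staying fully vectorized.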