If you do the computation on mu_p = mu[:, None, :] and mu_q = mu[None, :, :] (and similarly for std), and replace dim=1 with dim=-1 (i.e. the last dimension), you should be good to go thanks to broadcasting. Obviously, it would be even neater if we could avoid materializing the [B, B, H] intermediates, but that is much harder (and might be done automatically by torch.compile or similar, now or in the future).
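To make that concrete, here is a minimal sketch. I am assuming the computation is the KL divergence between diagonal Gaussians and that mu and std both have shape [B, H]; the function name pairwise_gaussian_kl is just something I made up for illustration, so swap in your own per-element formula as needed.

```python
import torch


def pairwise_gaussian_kl(mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Return a [B, B] tensor whose (i, j) entry is KL(N_i || N_j).

    Assumption: mu and std are [B, H] and describe diagonal Gaussians.
    """
    # Insert singleton dimensions so that p varies along dim 0 and q along dim 1.
    mu_p, std_p = mu[:, None, :], std[:, None, :]  # [B, 1, H]
    mu_q, std_q = mu[None, :, :], std[None, :, :]  # [1, B, H]

    var_p, var_q = std_p ** 2, std_q ** 2

    # Broadcasting expands every term to [B, B, H].
    kl = torch.log(std_q / std_p) + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5

    # Reduce over the feature dimension, which is now the last one (dim=-1).
    return kl.sum(dim=-1)


if __name__ == "__main__":
    mu = torch.randn(4, 3)
    std = torch.rand(4, 3) + 0.5  # keep std strictly positive
    print(pairwise_gaussian_kl(mu, std).shape)  # torch.Size([4, 4])
```

Note that this still builds the full [B, B, H] intermediate, so memory grows quadratically with the batch size.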
Best regards
Thomas