Hi all,
We have a question: how can we compute PD(z)Q in batch, where D(z) is the diagonal matrix built from z and z.size() = (batch_size, dim)?

Thanks.
Use torch.diag_embed.
torch.diag_embed(input, offset=0, dim1=-2, dim2=-1) → Tensor
Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input. To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
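A minimal sketch of how this could look, assuming P has shape (m, dim) and Q has shape (dim, n) and both are shared across the batch (the question does not state their shapes). The function names batched_pdq and batched_pdq_no_diag are hypothetical. The first version uses torch.diag_embed as suggested above; the second is a swapped-in alternative that relies on the fact that right-multiplying by a diagonal matrix just scales the columns of P, so the (batch_size, dim, dim) tensor is never built.

```python
import torch

def batched_pdq(P: torch.Tensor, z: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    """Compute P @ diag(z[b]) @ Q for every row b of z.

    Assumed shapes: P (m, dim), z (batch_size, dim), Q (dim, n).
    Returns a tensor of shape (batch_size, m, n).
    """
    D = torch.diag_embed(z)  # (batch_size, dim, dim): one diagonal matrix per row of z
    return P @ D @ Q         # matmul broadcasts the 2D P and Q over the batch dimension

def batched_pdq_no_diag(P: torch.Tensor, z: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    # (m, dim) * (batch_size, 1, dim) -> (batch_size, m, dim): column k of P scaled by z[b, k],
    # which equals P @ diag(z[b]) without allocating the dim x dim matrices.
    return (P * z.unsqueeze(-2)) @ Q

torch.manual_seed(0)
batch_size, dim, m, n = 4, 5, 3, 2
P = torch.randn(m, dim)
Q = torch.randn(dim, n)
z = torch.randn(batch_size, dim)

out = batched_pdq(P, z, Q)
print(out.shape)  # torch.Size([4, 3, 2])
```

Both functions give the same result; the second one avoids the O(batch_size * dim^2) memory of the explicit diagonal matrices, which may matter when dim is large.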