Calculating outer subtraction for a batch

I know that the outer subtraction of two vectors can be calculated using broadcasting, as given in


But how can I do this for batches? The first dimension is the batch size and the second dimension holds the vectors, and I want to calculate the outer subtraction between each vector in the first tensor and every vector in the second tensor.
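A minimal PyTorch sketch, assuming the two inputs have shapes (B, N) and (B, M) and that the vector at batch index b in the first tensor is paired with the vector at the same index b in the second (these shapes follow the linked batch outer addition question; the function name `batch_outer_sub` is hypothetical). Inserting a singleton dimension into each tensor lets broadcasting pair every element of `x[b]` with every element of `y[b]`, so that `out[b, i, j] = x[b, i] - y[b, j]`:

```python
import torch


def batch_outer_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Batched outer subtraction (hypothetical helper name).

    x: (B, N), y: (B, M) -> result: (B, N, M) with
    result[b, i, j] = x[b, i] - y[b, j].
    """
    # x.unsqueeze(2) has shape (B, N, 1) and y.unsqueeze(1) has shape (B, 1, M);
    # broadcasting expands both to (B, N, M) before subtracting.
    return x.unsqueeze(2) - y.unsqueeze(1)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])                # (2, 2)
y = torch.tensor([[10.0, 20.0, 30.0], [1.0, 1.0, 1.0]])  # (2, 3)
out = batch_outer_sub(x, y)
print(out.shape)  # torch.Size([2, 2, 3])
```

The same expression also works when each position holds a feature vector: for inputs of shape (B, N, D) and (B, M, D), `x.unsqueeze(2) - y.unsqueeze(1)` yields (B, N, M, D). If every vector should instead be paired with every vector across the whole batch, `x[:, None, :, None] - y[None, :, None, :]` gives shape (B, B, N, M).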

See https://stackoverflow.com/questions/55739993/pytorch-batch-outer-addition