Computing the difference between one channel of a tensor and all other channels within that tensor

How can I compute the difference between one channel of a tensor and all other channels within that tensor?
In other words, what is the best way to do this:
B,C,W,H = 2,3,20,20
A = torch.rand(B,C,W,H)
I want a tensor D defined as follows:

D[0,0,:,:] = torch.max(A[0:1,0:1,:,:]-A[0:1,:,:,:],dim=1)[0]
D[0,1,:,:] = torch.max(A[0:1,1:2,:,:]-A[0:1,:,:,:],dim=1)[0]
D[0,2,:,:] = torch.max(A[0:1,2:3,:,:]-A[0:1,:,:,:],dim=1)[0]
D[1,0,:,:] = ...
D[1,1,:,:] = ...
D[1,2,:,:] = ...
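
For concreteness, the pattern above can be written out as an explicit loop over batch and channel. This is only a sketch of the definition, meant as a slow reference to check faster versions against; the function name `reference_D` is made up for illustration.

```python
import torch

def reference_D(A: torch.Tensor) -> torch.Tensor:
    """Slow reference: D[b, c] = max over k of (A[b, c] - A[b, k])."""
    B, C, W, H = A.shape
    D = torch.empty_like(A)
    for b in range(B):
        for c in range(C):
            # A[b, c:c+1] has shape (1, W, H) and broadcasts against A[b], shape (C, W, H)
            D[b, c] = torch.max(A[b, c:c+1] - A[b], dim=0)[0]
    return D

B, C, W, H = 2, 3, 20, 20
A = torch.rand(B, C, W, H)
D = reference_D(A)
print(D.shape)  # torch.Size([2, 3, 20, 20])
```

Because the maximum also runs over k = c, where the difference is zero, every entry of D is non-negative.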

Now imagine that in the real case B and C are very large.
What is the fastest way to compute D?

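The answer is something like this, using broadcasting instead of a loop over channels:

A_1 = torch.unsqueeze(A, 1)  # shape (B, 1, C, W, H): axis 2 indexes the other channel k
A_2 = torch.unsqueeze(A, 2)  # shape (B, C, 1, W, H): axis 1 indexes the reference channel c
# (A_2 - A_1)[b, c, k] = A[b, c] - A[b, k], so reduce over k (dim 2).
# Note the order: A_1 - A_2 would give max_k(A[b, k] - A[b, c]) instead.
D = torch.max(A_2 - A_1, dim=2)[0]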
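
Since A[b, c] does not depend on the channel k being maximised over, max_k (A[b, c] - A[b, k]) equals A[b, c] - min_k A[b, k]. That suggests a version that never builds a (B, C, C, W, H) intermediate, which may matter when C is large. The following is a minimal sketch, not a benchmarked result; the function names `max_diff_broadcast` and `max_diff_min` are hypothetical and chosen only for this example.

```python
import torch

def max_diff_broadcast(A: torch.Tensor) -> torch.Tensor:
    # (B, C, 1, W, H) - (B, 1, C, W, H) -> (B, C, C, W, H);
    # entry [b, c, k] is A[b, c] - A[b, k], reduced over k.
    return torch.max(A.unsqueeze(2) - A.unsqueeze(1), dim=2)[0]

def max_diff_min(A: torch.Tensor) -> torch.Tensor:
    # Per-pixel minimum over channels; keepdim=True keeps shape (B, 1, W, H)
    # so it broadcasts back against A.
    return A - A.min(dim=1, keepdim=True)[0]

A = torch.rand(2, 3, 20, 20)
D_broadcast = max_diff_broadcast(A)
D_min = max_diff_min(A)
print(torch.allclose(D_broadcast, D_min))
```

The broadcast version needs memory proportional to B*C*C*W*H, while the min-based version only needs B*C*W*H, so the second form is likely the better fit when B and C are large.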