How to shrink a tensor by removing specific elements

Given a tensor with dimensions b x c x h x w, I'd like to remove the maximum value along dimension c and build a new tensor with dimensions b x (c-1) x h x w.
I can do it with a loop, but it is too slow.

Can anyone suggest what I can use instead of the for loop here?
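One loop-free option, sketched under the assumption that exactly one maximal entry per (b, h, w) position should be removed: find the position of the maximum along dim 1 with `argmax(dim=1, keepdim=True)`, build for each pixel the list of channel indices that skips that position, and pick those channels with `torch.gather`. The helper name `drop_max_channel_gather` is hypothetical. Because `argmax` returns a single index per pixel, ties remove only one of the tied values rather than being rejected as in the `if` check of the loop version; whether that is the desired behaviour is an assumption.

```python
import torch


def drop_max_channel_gather(mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop one max along dim 1, (b, c, h, w) -> (b, c - 1, h, w)."""
    c = mat.shape[1]
    # Channel index of the maximum for every (b, h, w) position, shape (b, 1, h, w).
    idx = mat.argmax(dim=1, keepdim=True)
    # Output channel slots 0 .. c-2, shaped to broadcast against idx.
    slots = torch.arange(c - 1, device=mat.device).view(1, c - 1, 1, 1)
    # Slots at or after the max shift up by one, so the max channel is skipped.
    keep_idx = slots + (slots >= idx).long()  # shape (b, c - 1, h, w)
    # Pick the kept channels; works for any batch size b.
    return torch.gather(mat, 1, keep_idx)
```

For the example's `mat = torch.rand(1, 3, 4, 4)`, `drop_max_channel_gather(mat)` returns a tensor of shape (1, 2, 4, 4), with the remaining channels in their original order.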

Here is an example:

import torch

mat = torch.rand(1, 3, 4, 4)
b, c, h, w = mat.shape

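# max over the channel dimension, one value per pixel (the reshape assumes b == 1)
mat_max = torch.max(mat, dim=1)[0].reshape(h * w, 1)
mat_empt = torch.empty((h * w, c - 1))

# one row per pixel, holding its c channel values
mat_shape = torch.reshape(mat.permute(0, 2, 3, 1), (h * w, c))
A = mat_shape == mat_max
# only proceed if no pixel has more than one channel equal to its max
if len(torch.where(torch.sum(A, dim=1) > 1)[0]) == 0:
    for indx, k in enumerate(mat_shape):
        # keep every channel except the max
        mat_empt[indx] = k[~A[indx]]
# else:
#     need to fix it later (ties: several channels equal to the max)

new_mat = torch.reshape(mat_empt, (h, w, c - 1)).unsqueeze(0).permute(0, 3, 1, 2)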
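Staying closer to the mask `A` in the example, a possible vectorized variant builds the "keep" mask for every pixel at once and applies it in a single boolean-indexing step instead of row by row. This is a sketch that assumes one maximal entry per pixel is dropped (ties are resolved by whichever index `argmax` returns); the helper name `drop_max_channel_mask` is hypothetical. Channels are moved last before indexing so the c - 1 kept values of each pixel come out next to each other, which is what lets the flat result be reshaped back.

```python
import torch


def drop_max_channel_mask(mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop one max along dim 1, (b, c, h, w) -> (b, c - 1, h, w)."""
    b, c, h, w = mat.shape
    # Channel index of the maximum for every pixel, shape (b, 1, h, w).
    idx = mat.argmax(dim=1, keepdim=True)
    # True for every channel except the argmax; broadcasts to (b, c, h, w).
    channels = torch.arange(c, device=mat.device).view(1, c, 1, 1)
    keep = channels != idx
    # Channels-last view: boolean indexing returns a flat tensor in row-major
    # order, so each pixel's c - 1 kept values are consecutive.
    kept = mat.permute(0, 2, 3, 1)[keep.permute(0, 2, 3, 1)]
    # Rebuild (b, h, w, c - 1), then move channels back to dim 1.
    return kept.reshape(b, h, w, c - 1).permute(0, 3, 1, 2)
```

Unlike the reshapes to `(h * w, ...)` in the example, this does not assume b == 1. The returned tensor is a permuted view; call `.contiguous()` on it if later code needs contiguous memory.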