How to threshold each channel in a tensor selectively

I have a tensor A and I want to threshold each of its channels separately:
if a value in a channel is greater than that channel's mean, set it to 1; if it is less than the mean, set it to 0.
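Written out for a tensor of shape (N, C, H, W), the rule compares each element with the mean of its own channel over the H×W spatial positions. The symbols T and μ are introduced here only for illustration, and "otherwise" sends values equal to the mean to 0, matching the `>` used in the code:

$$
T_{n,c,h,w} =
\begin{cases}
1 & \text{if } A_{n,c,h,w} > \mu_{n,c} \\
0 & \text{otherwise}
\end{cases}
\qquad
\mu_{n,c} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} A_{n,c,h,w}
$$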
Here is what I did:

import torch

A = torch.rand(1,10,8,8)
# Mean over dim 3, then dim 2 -> shape (1, 10)
Means = torch.mean(torch.mean(A,dim=3),dim=2)
Ones = torch.ones(A.size())
Zeros = torch.zeros(A.size())
Thresholded = torch.where(A > Means, Ones, Zeros)

But it does not work: the comparison A > Means raises a size-mismatch error.

Any suggestions?

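The shape of Means must be broadcastable to the shape of A. Averaging over dims 3 and 2 leaves Means with shape (1, 10), which does not broadcast against (1, 10, 8, 8). Adding two trailing singleton dimensions gives (1, 10, 1, 1), which does: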

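import torch

A = torch.rand(1,10,8,8)
# Per-channel mean, reshaped from (1, 10) to (1, 10, 1, 1) so it broadcasts over the 8x8 spatial dims
Means = torch.mean(torch.mean(A,dim=3),dim=2).unsqueeze(-1).unsqueeze(-1)
Ones = torch.ones(A.size())
Zeros = torch.zeros(A.size())
# 1 where an element exceeds its channel's mean, 0 otherwise
Thresholded = torch.where(A > Means, Ones, Zeros)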
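A shorter variant of the same fix, offered as a sketch rather than the answer above: `Tensor.mean` accepts a tuple of dims and `keepdim=True`, which keeps the reduced dims as size 1 so no `unsqueeze` is needed. Casting the boolean mask to A's dtype also avoids allocating the `Ones`/`Zeros` tensors. The names `means` and `thresholded` are just illustrative:

```python
import torch

# Same input shape as in the question: batch 1, 10 channels, 8x8 spatial maps
A = torch.rand(1, 10, 8, 8)

# Mean over the spatial dims; keepdim=True keeps shape (1, 10, 1, 1) for broadcasting
means = A.mean(dim=(2, 3), keepdim=True)

# The boolean mask cast to A's dtype gives 1.0 where A > channel mean, else 0.0
thresholded = (A > means).to(A.dtype)

print(means.shape)        # torch.Size([1, 10, 1, 1])
print(thresholded.shape)  # torch.Size([1, 10, 8, 8])
```

The result should match the `torch.where` version element for element.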

Do you know what I should do if I want to use torch.clamp(A, min=Means) instead of torch.where?
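One point worth noting for this follow-up: clamping does not produce a 0/1 mask. With a per-channel lower bound, values below the channel mean are raised to the mean and values above it are left unchanged. The same broadcastable (1, 10, 1, 1) shape is needed. The sketch below uses `torch.maximum`, which broadcasts its two tensor arguments; recent PyTorch versions also accept a tensor for `torch.clamp`'s `min`, but check your version before relying on that:

```python
import torch

A = torch.rand(1, 10, 8, 8)
# Per-channel mean with shape (1, 10, 1, 1), broadcastable against A
means = A.mean(dim=(2, 3), keepdim=True)

# Element-wise max against the broadcast mean:
# values below their channel mean become the mean, values above are kept as-is.
clamped = torch.maximum(A, means)

# On PyTorch versions where torch.clamp accepts tensor bounds, this should be equivalent:
# clamped = torch.clamp(A, min=means)
```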