I have a tensor A and I want to threshold each of its channels separately.
That is, if a value in a channel is greater than the mean of that channel, set it to 1; otherwise, set it to 0.
Here is what I did:
import torch

A = torch.rand(1, 10, 8, 8)  # shape (1, 10, 8, 8)
Means = torch.mean(torch.mean(A, dim=3), dim=2)  # per-channel mean, shape (1, 10)
Ones = torch.ones(A.size())
Zeros = torch.zeros(A.size())
Thresholded = torch.where(A > Means, Ones, Zeros)
But it does not work: the comparison A > Means raises a size-mismatch error.
Any suggestions?
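
A likely cause is that the per-channel means have shape (1, 10), which cannot broadcast against a tensor of shape (1, 10, 8, 8), because broadcasting aligns trailing dimensions (8 vs. 10). Here is a minimal sketch of one possible fix, assuming the means should be taken over the two spatial dimensions and kept with shape (1, 10, 1, 1). The lowercase names means and thresholded are only illustrative.

```python
import torch

A = torch.rand(1, 10, 8, 8)

# Per-channel mean over the two spatial dims (2 and 3).
# keepdim=True keeps the result at shape (1, 10, 1, 1),
# so it broadcasts against A's shape (1, 10, 8, 8).
means = A.mean(dim=(2, 3), keepdim=True)

# The boolean mask is cast to float: 1.0 where the value exceeds
# its channel mean, 0.0 elsewhere.
thresholded = (A > means).float()
```

With keepdim=True the explicit Ones and Zeros tensors are not needed, although torch.where(A > means, Ones, Zeros) should give the same result once the shapes broadcast.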