# Channelwise Threshold

I am trying to apply a different threshold to each input channel (similar to PReLU with `num_parameters=nChannels`), but the thresholds do not need to be learned. So far I have the following, but it is too slow due to the for loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        h = []
        for i in range(len(thresholds)):
            h.append(F.threshold(x[:, i], thresholds[i], 0))
        return torch.stack(h, dim=1)
```



Is there any more efficient way of achieving this?

Thanks,

I think the stacking operation in your case is what's expensive. I checked the PReLU code and it seems to be implemented in C. However, doing the operation in place should make it fast enough:

```python
class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        for i in range(len(thresholds)):
            F.threshold_(x[:, i], thresholds[i], 0)
        return x
```


Sorry if I was unclear: the thresholds do not need to be learned, but x may be a Variable that requires gradients. If I try to use this code I get the following error:

```
RuntimeError: one of the variables needed for the gradient computation has been modified by an inplace operation
```

I assume this is due to repeated inplace operations on the Variable x.

Also, I tried removing the stack operation (just allocating a new Variable before entering the for loop) and it is still slow (from 45 seconds per epoch to 30 minutes per epoch). For larger layers (e.g., C = 500), this translates to 500 calls to F.threshold for each forward pass through that layer.

I’m not sure what your shapes are, but could this work?

```python
a = torch.randn(3, 2)
th = torch.randn(3, 1)

print(a)
# tensor([[ 1.1200, -0.1365],
#         [-0.1328,  0.0842],
#         [-0.5945,  0.6210]])
print(th)
# tensor([[-0.7215],
#         [ 1.1406],
#         [-0.4683]])

idx = (a > th).float()
print(a * idx)
# tensor([[ 1.1200, -0.1365],
#         [-0.0000,  0.0000],
#         [-0.0000,  0.6210]])
```


Thanks! I think this could work. Do you know how to do this over, say, dim=1 in the case of 4 dimensions?

```python
a = torch.randn(3, 2, 4, 4)
th = torch.randn(1, 2, 1, 1)
idx = (a > th).float()  # broadcasts over dim=1; idx has shape (3, 2, 4, 4)
```

I could always permute the data to make it two dimensions as in your example, but I think that would be a bit slower.

Edit: Disregard this comment – I didn’t realize the semantics already work this way!
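
Putting the pieces together, here is a sketch of a fully vectorized version of the original module using the broadcasting trick above (the assumption that `thresholds` is a 1-D tensor of length C for an input of shape (N, C, H, W) is mine, as is the masking formulation, which matches `F.threshold` only for a replacement value of 0):

```python
import torch
import torch.nn as nn

class ChannelwiseThreshold(nn.Module):
    """Zero out entries that are <= their channel's threshold, without a loop.

    Sketch: assumes x has shape (N, C, H, W) and thresholds is a 1-D tensor
    of length C. The comparison broadcasts over dim=1.
    """
    def forward(self, x, thresholds):
        # Reshape thresholds to (1, C, 1, 1) so the comparison broadcasts.
        th = thresholds.view(1, -1, 1, 1)
        return x * (x > th).float()
```

This replaces the C separate `F.threshold` calls with a single elementwise comparison and multiply, which autograd also handles without any in-place modification of x.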