Channelwise Threshold

I am trying to apply a different threshold to each input channel (similar to PReLU with a separate \alpha per channel, i.e. num_parameters=nChannels), except that the thresholds do not need to be learned. So far I have the following, but it is too slow because of the for loop.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        # Apply a separate threshold to each channel, one slice at a time.
        h = []
        for i in range(len(thresholds)):
            h.append(F.threshold(x[:, i], thresholds[i], 0))

        return torch.stack(h, dim=1)
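
For context, this is roughly how the module gets used; the shapes and threshold values below are just placeholders:

# Hypothetical usage: x is a batch of feature maps of shape (N, C, H, W)
# with one fixed (non-learned) threshold per channel.
layer = ChannelwiseThreshold()
x = torch.randn(8, 500, 16, 16)
thresholds = [0.1] * 500        # plain Python floats, one per channel
out = layer(x, thresholds)      # 500 separate F.threshold calls per forward pass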

Is there any more efficient way of achieving this?

Thanks,
Brad

I think the stacking operation in your case is expensive. I checked the PReLU code and it seems to be done in C. However, doing the operation in place should make your version fast enough:

class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        # Threshold each channel slice in place to avoid building a list
        # and stacking it afterwards.
        for i in range(len(thresholds)):
            F.threshold_(x[:, i], thresholds[i], 0)

        return x

Sorry if I was unclear: the thresholds do not need to be learned, but x may be a Variable that requires gradients. If I try to use this code I get the following error:

RuntimeError: one of the variables needed for the gradient computation has been modified by an inplace operation

I assume this is due to the repeated in-place operations on the Variable x.

Also, I tried removing the stack operation (allocating a new Variable before entering the for loop and writing each thresholded channel into it), and it is still slow: training went from 45 seconds per epoch to 30 minutes per epoch. For larger layers (e.g., C = 500), this translates into 500 calls to F.threshold for every forward pass through that layer.
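
For reference, the no-stack variant I tried looked roughly like this (a sketch from memory; the exact allocation code may have differed slightly):

class ChannelwiseThreshold(nn.Module):
    def forward(self, x, thresholds):
        # Allocate the output once and fill it channel by channel instead
        # of stacking a list of slices. The Python loop still issues one
        # F.threshold call per channel, so it stays slow for large C.
        out = torch.zeros_like(x)
        for i in range(len(thresholds)):
            out[:, i] = F.threshold(x[:, i], thresholds[i], 0)
        return out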

I’m not sure what your shapes are, but could this work?

a = torch.randn(3, 2)
th = torch.randn(3, 1)

print(a)
> tensor([[ 1.1200, -0.1365],
        [-0.1328,  0.0842],
        [-0.5945,  0.6210]])
print(th)
> tensor([[-0.7215],
        [ 1.1406],
        [-0.4683]])

idx = (a > th).float()
print(a * idx)
> tensor([[ 1.1200, -0.1365],
        [-0.0000,  0.0000],
        [-0.0000,  0.6210]])

Thanks! I think this could work. Do you know how to do this over, say, dim=1 in the case of 4 dimensions?

a = torch.randn(3, 2, 4, 4)
th = torch.randn(1, 2, 1, 1)
idx = (a > th).float()  # compare along dim=1; idx would have shape (3, 2, 4, 4)

I could always permute and reshape the data down to two dimensions as in your example, but I think that would be a bit slower.

Edit: Disregard this question; I didn't realize broadcasting already works this way!
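
For completeness, a vectorized version of the module built on this broadcasting trick could look roughly like the sketch below, assuming x has shape (N, C, H, W) and thresholds is a tensor with C elements (the names and shapes here are mine):

import torch
import torch.nn as nn


class ChannelwiseThreshold(nn.Module):
    def forward(self, x, thresholds):
        # Reshape thresholds to (1, C, 1, 1) so the comparison broadcasts
        # over the batch and spatial dimensions; entries at or below their
        # channel's threshold are zeroed, as F.threshold(..., 0) would do.
        th = thresholds.view(1, -1, 1, 1)
        return x * (x > th).float()


layer = ChannelwiseThreshold()
x = torch.randn(8, 500, 16, 16)
thresholds = torch.linspace(0.1, 1.0, 500)
out = layer(x, thresholds)  # one vectorized comparison instead of 500 slice calls

One small difference: multiplying by the mask leaves signed zeros (the -0.0000 entries in the output above) where F.threshold would write an exact 0, which is usually harmless.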