# Channelwise Threshold

I am trying to apply a different threshold to each input channel (similar to PReLU with `num_parameters=nChannels`), but the thresholds do not need to be learned. So far I have the following, but it is too slow due to the for loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        h = []
        for i in range(len(thresholds)):
            h.append(F.threshold(x[:, i], thresholds[i], 0))
        return torch.stack(h, dim=1)
```



Is there any more efficient way of achieving this?

Thanks,

I think the stacking operation in your case is what's expensive. I checked the PReLU code and it seems to be implemented in C. However, doing the operation in place should make it fast enough:

```python
class ChannelwiseThreshold(nn.Module):
    def __init__(self):
        super(ChannelwiseThreshold, self).__init__()

    def forward(self, x, thresholds):
        for i in range(len(thresholds)):
            F.threshold_(x[:, i], thresholds[i], 0)
        return x
```


Sorry if I was unclear: the thresholds do not need to be learned, but x may be a Variable that requires gradients. If I try to use this code I get the following error:

```
RuntimeError: one of the variables needed for the gradient computation has been modified by an inplace operation
```

I assume this is due to repeated inplace operations on the Variable x.

Also, I tried removing the stack operation (just allocating a new Variable before entering the for loop) and it is still slow (from 45 seconds per epoch to 30 minutes per epoch). For larger layers (e.g., C = 500), this translates to 500 calls to F.threshold for each forward pass through that layer.

I’m not sure what your shapes are, but could this work?

```python
a = torch.randn(3, 2)
th = torch.randn(3, 1)

print(a)
# tensor([[ 1.1200, -0.1365],
#         [-0.1328,  0.0842],
#         [-0.5945,  0.6210]])
print(th)
# tensor([[-0.7215],
#         [ 1.1406],
#         [-0.4683]])

idx = (a > th).float()
print(a * idx)
# tensor([[ 1.1200, -0.1365],
#         [-0.0000,  0.0000],
#         [-0.0000,  0.6210]])
```


Thanks! I think this could work. Do you know how to do this over, say, dim=1 in the case of 4 dimensions?

```python
a = torch.randn(3, 2, 4, 4)
th = torch.randn(1, 2, 1, 1)
idx = (a > th).float()  # broadcasts over dim=1; idx has shape (3, 2, 4, 4)
```

I could always permute the data to make it two dimensions as in your example, but I think that would be a bit slower.

Edit: Disregard this comment – I didn’t realize the semantics already work this way!
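
Putting the pieces together, here is a sketch of a fully vectorized version of the original module using the broadcasting trick above (the assumption that `thresholds` is a 1-D tensor of length C for an input of shape (N, C, H, W) is mine, as is the masking formulation, which matches `F.threshold` only for a replacement value of 0):

```python
import torch
import torch.nn as nn

class ChannelwiseThreshold(nn.Module):
    """Zero out entries that are <= their channel's threshold, without a loop.

    Sketch: assumes x has shape (N, C, H, W) and thresholds is a 1-D tensor
    of length C. The comparison broadcasts over dim=1.
    """
    def forward(self, x, thresholds):
        # Reshape thresholds to (1, C, 1, 1) so the comparison broadcasts.
        th = thresholds.view(1, -1, 1, 1)
        return x * (x > th).float()
```

This replaces the C separate `F.threshold` calls with a single elementwise comparison and multiply, which autograd also handles without any in-place modification of x.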