Trainable HardShrink threshold

Hi, how can the behavior of torch.nn.Hardshrink (found here https://pytorch.org/docs/stable/nn.html#torch.nn.Hardshrink) be implemented with a learnable lambd parameter?

We’ve tried wrapping lambd in an nn.Parameter as shown below, but it doesn’t seem to work.

import torch
import torch.nn as nn

class foo(nn.Module):
    def __init__(self, eps=None):
        super().__init__()
        # initialize eps as a trainable parameter
        if eps is None:
            self.eps = nn.Parameter(torch.tensor(0.1))
        else:
            self.eps = nn.Parameter(torch.tensor(eps))

        self.eps.requires_grad = True  # redundant: Parameters require grad by default

    def forward(self, x):
        # can't give self.eps directly here; .data only passes the detached value
        return torch.hardshrink(x, self.eps.data)
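
As a quick illustration of why this first version can’t learn eps (a minimal check, assuming the class above is used as-is): .data passes only a detached copy of the threshold, so eps never enters the autograd graph and never receives a gradient.

m = foo()
x = torch.randn(4, requires_grad=True)
m(x).sum().backward()
print(m.eps.grad)  # None: the detached .data value never entered the graph
print(x.grad)      # populated, since the input did go through hardshrink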

We’ve also tried implementing it ourselves using the clamp function, as follows:

class foo(nn.Module):
    def __init__(self, eps=None):
        super().__init__()
        # initialize eps as a trainable parameter
        if eps is None:
            self.eps = nn.Parameter(torch.tensor(0.1))
        else:
            self.eps = nn.Parameter(torch.tensor(eps))

        self.eps.requires_grad = True  # redundant: Parameters require grad by default

    def forward(self, x):
        e = torch.abs(self.eps)
        # positive side: zero out x <= e, then add e back to the surviving entries
        y = (x - e).clamp(min=0)
        y[y != 0] = y[y != 0] + e
        # negative side: same construction on -x
        z = (-x - e).clamp(min=0)
        z[z != 0] = z[z != 0] + e
        return y - z

But even here, the gradient w.r.t. eps seems to go to 0. Since autograd supports indexing, shouldn’t a nonzero gradient flow back to eps? After the indexed assignment, y has grad_fn=<IndexPutBackward>. How is this implemented in PyTorch?

Any help on how to proceed will be appreciated. Thank you.

Hi,

Autograd has support for indexing, but it cannot compute gradients that don’t exist.
In particular, since lambda is a threshold, the gradient with respect to lambda is 0 almost everywhere. So you cannot learn lambda with gradient-descent-style methods.
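
Concretely, writing out the definition (just restating the docs linked above):

\mathrm{HardShrink}(x, \lambda) =
\begin{cases}
x, & |x| > \lambda \\
0, & \text{otherwise}
\end{cases}
\qquad\Rightarrow\qquad
\frac{\partial\, \mathrm{HardShrink}(x, \lambda)}{\partial \lambda} = 0
\quad \text{wherever } |x| \neq \lambda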

Hi @albanD, do you know of any way we could set this up correctly? Maybe some sort of manual gradient method?

If you have a better formula for your gradient, you can use a custom Function to specify what the backward should compute (see the torch.autograd.Function docs).
But I’m not sure what that formula would be, as the true gradient is just 0 here.
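
Just to show the mechanics, here is a minimal sketch of such a custom Function, assuming you are willing to pick some surrogate formula for the lambd gradient (below I use the derivative the soft-threshold operator would have; whether that surrogate actually trains well depends on your setup, and the names HardShrinkWithSurrogateGrad / LearnableHardShrink are made up for this example):

import torch

class HardShrinkWithSurrogateGrad(torch.autograd.Function):
    # Hardshrink whose backward returns a hand-picked gradient for lambd.
    # The true gradient w.r.t. lambd is 0 almost everywhere, so grad_lambd
    # below is a surrogate (the soft-threshold derivative), not an exact one.

    @staticmethod
    def forward(ctx, x, lambd):
        mask = x.abs() > lambd  # elements that survive the shrink
        ctx.save_for_backward(x, mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_output):
        x, mask = ctx.saved_tensors
        # exact gradient w.r.t. the input: 1 where |x| > lambd, 0 elsewhere
        grad_x = grad_output * mask
        # surrogate gradient w.r.t. lambd: d/dlambd of sign(x) * max(|x| - lambd, 0)
        # is -sign(x) on the surviving elements
        grad_lambd = (-grad_output * torch.sign(x) * mask).sum()
        return grad_x, grad_lambd


class LearnableHardShrink(torch.nn.Module):
    def __init__(self, lambd=0.1):
        super().__init__()
        self.lambd = torch.nn.Parameter(torch.tensor(float(lambd)))

    def forward(self, x):
        return HardShrinkWithSurrogateGrad.apply(x, self.lambd)

With this wrapper, lambd does receive a (surrogate) gradient:

m = LearnableHardShrink()
m(torch.randn(5)).sum().backward()
print(m.lambd.grad)  # populated, via the surrogate formula in backward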