Excluding torch.clamp() from backpropagation (like tf.stop_gradient in TensorFlow)

Hi,

When using torch.clamp(), the derivative w.r.t. its input is zero if the input is outside [min, max]. By the chain rule, this makes the gradients of all earlier operations in the graph zero as well.
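A quick way to see this is to look at the gradient of torch.clamp for inputs below, inside and above the range. This is a small sketch; the three sample values are arbitrary.

```python
import torch

# Three sample inputs: below, inside and above the range [4, 6]
y = torch.tensor([1.0, 5.0, 8.0], requires_grad=True)

out = torch.clamp(y, min=4, max=6)
out.sum().backward()  # upstream gradient is 1 for every element

# The gradient is only passed through where 4 <= y <= 6
print(y.grad)  # tensor([0., 1., 0.])
```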

In TensorFlow, one can use tf.stop_gradient (https://www.tensorflow.org/api_docs/python/tf/stop_gradient) to prevent this behavior. Is there something similar for PyTorch?
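PyTorch's counterpart of tf.stop_gradient is Tensor.detach(). One common way to use it here is the straight-through pattern y + (clamp(y) - y).detach(): the forward value equals clamp(y), but the detached term is treated as a constant, so the gradient w.r.t. y is 1 everywhere. A minimal sketch; the helper name straight_through_clamp and its parameters are hypothetical.

```python
import torch


def straight_through_clamp(y, min_val, max_val):
    # Hypothetical helper.
    # Forward: y + (clamp(y) - y) == clamp(y)
    # Backward: the detached term is a constant, so d(out)/dy = 1
    return y + (y.clamp(min=min_val, max=max_val) - y).detach()


y = torch.tensor([1.0, 5.0, 8.0], requires_grad=True)
out = straight_through_clamp(y, 4.0, 6.0)
out.sum().backward()

print(out.detach())  # tensor([4., 5., 6.])
print(y.grad)        # tensor([1., 1., 1.])
```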


Below is a minimal working example. A neural net with a single weight is supposed to push its output towards 5. To restrict the range of possible values, the output is clipped to [4, 6].

Ideally, the network should still find the optimal value despite the “prior”, but it gets stuck because the gradient is zero.

This would not happen if the derivative of the clamp function could be excluded from backpropagation.

import torch
import torch.nn as nn

x = torch.tensor([1.0])  # one scalar as input
layer = nn.Linear(1, 1, bias=False)  # neural net with one weight
optimizer = torch.optim.Adam(params=layer.parameters(), lr=1e-3)

for i in range(101):
    w = layer.weight  # weight before the optimizer step
    y = layer(x)  # y = w * x
    f_y = torch.clamp(y, min=4, max=6)  # f(y) = clip(y, 4, 6)
    loss = torch.abs(f_y - 5)  # absolute error, zero if f(y) = 5

    optimizer.zero_grad()
    loss.backward()
    grad = w.grad

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.item()))
        print('y: {:.2f}'.format(y.item()))
        print('f_y: {:.2f}'.format(f_y.item()))
        print('loss: {:.2f}'.format(loss.item()))
        print('grad: {:.2f}\n'.format(grad.item()))

    optimizer.step()

Output:

iteration 0
w: 0.96
y: 0.96
f_y: 4.00
loss: 1.00
grad: 0.00

iteration 100
w: 0.96
y: 0.96
f_y: 4.00
loss: 1.00
grad: 0.00

It turns out that the problem can be solved by creating a custom Clamp class with a custom backward method. The only remaining issue is that I do not know how to pass the min/max values as arguments.

import torch
import torch.nn as nn


class Clamp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=4, max=6)  # min/max are hard-coded for now

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()  # pass the gradient through unchanged


# autograd Functions are used via .apply on the class, not instantiated
clamp = Clamp.apply

x = torch.tensor([1.0])  # one scalar as input
layer = nn.Linear(1, 1, bias=False)  # neural net with one weight
optimizer = torch.optim.Adam(params=layer.parameters(), lr=1e-3)

for i in range(10001):
    w = layer.weight  # weight before the optimizer step
    y = layer(x)  # y = w * x
    f_y = clamp(y)  # f(y) = clip(y, 4, 6)
    loss = torch.abs(f_y - 5)  # absolute error, zero if f(y) = 5

    optimizer.zero_grad()
    loss.backward()
    grad = w.grad

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.item()))
        print('y: {:.2f}'.format(y.item()))
        print('f_y: {:.2f}'.format(f_y.item()))
        print('loss: {:.2f}'.format(loss.item()))
        print('grad: {:.2f}\n'.format(grad.item()))

    optimizer.step()

The plot below shows the output over the iterations. The output finally reaches the target value! 🙂


Personally, I use torch.sigmoid as a clamping function. It is more expensive, but its gradients (almost) never vanish; they only become very small when the sigmoid saturates.
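The post does not show how the sigmoid is mapped onto the target range; a natural assumption is min + (max - min) * sigmoid(x). The sketch below uses that scaling (the helper name sigmoid_clamp is hypothetical). Note that, unlike clamp, this also changes values inside the range, so it is a soft replacement rather than an exact one.

```python
import torch


def sigmoid_clamp(x, min_val, max_val):
    # Hypothetical helper: smoothly maps any real x into (min_val, max_val)
    return min_val + (max_val - min_val) * torch.sigmoid(x)


x = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)
out = sigmoid_clamp(x, 4.0, 6.0)
out.sum().backward()

print(out.detach())  # values strictly between 4 and 6; exactly 5 at x = 0
print(x.grad)        # positive everywhere, largest (0.5) at x = 0
```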

I found @always’s solution quite elegant for my own use case, and I have solved the problem of passing the min/max arguments:

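import torch
# Recent PyTorch releases deprecate these in favor of
# torch.amp.custom_fwd / torch.amp.custom_bwd with device_type='cuda'.
from torch.cuda.amp import custom_bwd, custom_fwd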


class DifferentiableClamp(torch.autograd.Function):
    """
    In the forward pass this operation behaves like torch.clamp.
    But in the backward pass its gradient is 1 everywhere, as if instead of clamp one had used the identity function.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, min, max):
        return input.clamp(min=min, max=max)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output.clone(), None, None


def dclamp(input, min, max):
    """
    Like torch.clamp, but with a constant 1-gradient.
    :param input: The input that is to be clamped.
    :param min: The minimum value of the output.
    :param max: The maximum value of the output.
    """
    return DifferentiableClamp.apply(input, min, max)
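Plugging dclamp into the one-weight example from the question shows that the weight now keeps moving even while the output is clamped at 4. This sketch drops the AMP decorators (they only matter under autocast) so it runs on CPU, and it uses a larger learning rate (1e-2), a fixed seed and a sum-reduced loss; these are my choices, not part of the original posts.

```python
import torch
import torch.nn as nn


class DifferentiableClamp(torch.autograd.Function):
    # Same idea as above, without the AMP decorators
    @staticmethod
    def forward(ctx, input, min, max):
        return input.clamp(min=min, max=max)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient for input; min and max receive no gradient
        return grad_output.clone(), None, None


def dclamp(input, min, max):
    return DifferentiableClamp.apply(input, min, max)


torch.manual_seed(0)
x = torch.tensor([1.0])              # one scalar as input
layer = nn.Linear(1, 1, bias=False)  # neural net with one weight
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)

for _ in range(1001):
    f_y = dclamp(layer(x), 4, 6)     # clamped in the forward pass only
    loss = torch.abs(f_y - 5).sum()  # absolute error, zero if f(y) = 5
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('f_y: {:.2f}, loss: {:.2f}'.format(f_y.item(), loss.item()))
```

With the plain torch.clamp in place of dclamp, the same loop would leave the weight unchanged, as in the output shown in the question.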