# Excluding torch.clamp() from backpropagation (as tf.stop_gradient in tensorflow)

Hi,

when using torch.clamp(), the derivative w.r.t. its input is zero if the input is outside [min, max]. Due to the chain rule, this causes the gradients of all previous operations in the graph to become zero as well.
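A minimal check of this behavior: for an input outside the clamp range, the local derivative is 0, so the upstream gradient vanishes.

```python
import torch

# 1.0 lies outside [4, 6], so clamp's local derivative there is 0
y = torch.tensor([1.0], requires_grad=True)
torch.clamp(y, min=4, max=6).backward()
print(y.grad)  # zero gradient, despite the loss being far from its optimum
```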

In tensorflow, one can use tf.stop_gradient (https://www.tensorflow.org/api_docs/python/tf/stop_gradient) to prevent this behavior. Is there something similar for PyTorch?
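The closest built-in counterpart in PyTorch is `Tensor.detach()`: like `tf.stop_gradient`, it returns a tensor that shares the same values but is cut out of the autograd graph. A minimal sketch:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x.detach()  # treated as a constant during backprop, like tf.stop_gradient
z = x * y       # only the non-detached branch contributes: dz/dx = y
z.backward()
print(x.grad)   # tensor([2.])
```

Note that `detach()` blocks the gradient entirely; for clamp, what is wanted below is the opposite (pass the gradient through as if clamp were the identity), which needs a custom autograd function.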


Below is a minimal working example. A neural net with one weight is supposed to push its input towards 5. To reduce the range of possible values, the output is clipped to [4, 6].

Ideally, the network should still find the optimal value despite the "prior". But it gets stuck due to the zero gradient.

This would not happen if the derivative of the clamp function could be excluded from backpropagation.

```python
import torch
import numpy as np
import torch.nn as nn

x = torch.from_numpy(np.array([1]).astype(np.float32))  # one scalar as input
layer = nn.Linear(1, 1, bias=False)  # neural net with one weight
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)

for i in range(101):
    optimizer.zero_grad()
    w = list(layer.parameters())[0]  # weight before backprop
    y = layer(x)  # y = w * x
    f_y = torch.clamp(y, min=4, max=6)  # f(y) = clip(y)
    loss = torch.abs(f_y - 5)  # absolute error, zero if f(y) = 5

    loss.backward()

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.detach().numpy()[0][0]))
        print('y: {:.2f}'.format(y.detach().numpy()[0]))
        print('f_y: {:.2f}'.format(f_y.detach().numpy()[0]))
        print('loss: {:.2f}'.format(loss.detach().numpy()[0]))

    optimizer.step()
```
```
iteration 0
w: 0.96
y: 0.96
f_y: 4.00
loss: 1.00

iteration 100
w: 0.96
y: 0.96
f_y: 4.00
loss: 1.00
```

It turns out that the problem can be solved by creating a custom Clamp class with a custom backward method. The only remaining issue is that I do not know how to pass the min/max values as arguments.

```python
import torch
import numpy as np
import torch.nn as nn

class Clamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=4, max=6)

    @staticmethod
    def backward(ctx, grad_output):
        # pass the gradient through unchanged, as if clamp were the identity
        return grad_output.clone()

x = torch.from_numpy(np.array([1]).astype(np.float32))  # one scalar as input
layer = nn.Linear(1, 1, bias=False)  # neural net with one weight
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
clamp = Clamp.apply

for i in range(10001):
    optimizer.zero_grad()
    w = list(layer.parameters())[0]  # weight before backprop
    y = layer(x)  # y = w * x
    f_y = clamp(y)  # f(y) = clip(y)
    loss = torch.abs(f_y - 5)  # absolute error, zero if f(y) = 5

    loss.backward()

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.detach().numpy()[0][0]))
        print('y: {:.2f}'.format(y.detach().numpy()[0]))
        print('f_y: {:.2f}'.format(f_y.detach().numpy()[0]))
        print('loss: {:.2f}'.format(loss.detach().numpy()[0]))

    optimizer.step()
```

The plot below shows output over iterations. The output finally reaches the target value!


Personally, I use `torch.sigmoid` as a clamping function. It is more expensive, but the gradients (almost) never vanish.

I found @always's solution quite elegant for my own use case, and so have solved the problem of the min/max arguments:

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class DifferentiableClamp(torch.autograd.Function):
    """
    In the forward pass this operation behaves like torch.clamp.
    But in the backward pass its gradient is 1 everywhere, as if
    instead of clamp one had used the identity function.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, min, max):
        return input.clamp(min=min, max=max)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # identity gradient for the input; min and max take no gradient
        return grad_output.clone(), None, None
```
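As a side note, the same "straight-through" effect can also be obtained without a custom autograd function, by detaching the clamp correction from the graph (`st_clamp` is an illustrative name, not part of the thread's code):

```python
import torch

def st_clamp(y, lo, hi):
    # Forward pass: the clamped value.
    # Backward pass: the gradient of the identity, because the
    # (clamp - y) correction term is detached from the graph.
    return y + (y.clamp(min=lo, max=hi) - y).detach()
```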