When using torch.clamp(), the derivative w.r.t. its input is zero whenever the input lies outside [min, max]. By the chain rule, this also zeroes out the gradients of all preceding operations in the graph.
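For example, backpropagating through a clamped value that lies outside the range already yields a zero gradient (a quick check; the value 8.0 is just an arbitrary out-of-range example):

import torch

y = torch.tensor([8.0], requires_grad=True)  # outside [4, 6]
f_y = torch.clamp(y, min=4, max=6)           # clamped to 6
f_y.backward()
print(y.grad)                                # tensor([0.]) -> everything upstream receives zero gradient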
Below is a minimal working example. A neural net with a single weight is supposed to map its input to 5. To reduce the range of possible values, the output is clipped to [4, 6].
Ideally, the network should still find the optimal value despite this “prior”, but it gets stuck because the gradient is zero.
This would not happen if the derivative of the clamp function could be excluded from backpropagation.
import torch
import numpy as np
import torch.nn as nn

x = torch.from_numpy(np.array([1]).astype(np.float32))  # one scalar as input
layer = nn.Linear(1, 1, bias=False)                      # neural net with one weight
optimizer = torch.optim.Adam(params=layer.parameters(), lr=1e-3)

for i in range(101):
    w = list(layer.parameters())[0]     # weight before backprop
    y = layer(x)                        # y = w * x
    f_y = torch.clamp(y, min=4, max=6)  # f(y) = clip(y)
    loss = torch.abs(f_y - 5)           # absolute error, zero if f(y) = 5

    optimizer.zero_grad()
    loss.backward()
    grad = w.grad

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.detach().numpy()[0][0]))
        print('y: {:.2f}'.format(y.detach().numpy()[0]))
        print('f_y: {:.2f}'.format(f_y.detach().numpy()[0]))
        print('loss: {:.2f}'.format(loss.detach().numpy()[0]))
        print('grad: {:.2f}\n'.format(grad.detach().numpy()[0][0]))

    optimizer.step()
It turns out that the problem can be solved by creating a custom Clamp class with a custom backward method. The only remaining issue is that I do not know how to pass the min/max values as arguments.
import torch
import numpy as np
import torch.nn as nn

class Clamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=4, max=6)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

clamp = Clamp.apply

x = torch.from_numpy(np.array([1]).astype(np.float32))  # one scalar as input
layer = nn.Linear(1, 1, bias=False)                      # neural net with one weight
optimizer = torch.optim.Adam(params=layer.parameters(), lr=1e-3)

for i in range(10001):
    w = list(layer.parameters())[0]  # weight before backprop
    y = layer(x)                     # y = w * x
    f_y = clamp(y)                   # f(y) = clip(y)
    loss = torch.abs(f_y - 5)        # absolute error, zero if f(y) = 5

    optimizer.zero_grad()
    loss.backward()
    grad = w.grad

    if i % 100 == 0:
        print('iteration {}'.format(i))
        print('w: {:.2f}'.format(w.detach().numpy()[0][0]))
        print('y: {:.2f}'.format(y.detach().numpy()[0]))
        print('f_y: {:.2f}'.format(f_y.detach().numpy()[0]))
        print('loss: {:.2f}'.format(loss.detach().numpy()[0]))
        print('grad: {:.2f}\n'.format(grad.detach().numpy()[0][0]))

    optimizer.step()
The plot below shows output over iterations. The output finally reaches the target value!
I found @always’s solution quite elegant for my own use case, and have extended it to solve the problem of passing the min/max values as arguments:
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class DifferentiableClamp(torch.autograd.Function):
    """
    In the forward pass this operation behaves like torch.clamp.
    But in the backward pass its gradient is 1 everywhere, as if instead of clamp one had used the identity function.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, min, max):
        return input.clamp(min=min, max=max)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged; min and max receive no gradient.
        return grad_output.clone(), None, None

def dclamp(input, min, max):
    """
    Like torch.clamp, but with a constant 1-gradient.
    :param input: The input that is to be clamped.
    :param min: The minimum value of the output.
    :param max: The maximum value of the output.
    """
    return DifferentiableClamp.apply(input, min, max)
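Usage then looks like this (a small sketch assuming the definitions above):

y = torch.tensor([8.0], requires_grad=True)
out = dclamp(y, 4, 6)  # forward behaves like torch.clamp: out = tensor([6.])
out.backward()
print(y.grad)          # tensor([1.]) -> the gradient passes through as if clamp were the identity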