The following code works for me:
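```python
import torch
import torch.nn as nn

class Clamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # e.g. if z reaches 2 during training, the forward pass outputs 1
        return input.clamp(min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # pass the incoming gradient through unchanged, ignoring the clamp
        return grad_output.clone()
```

And in the `nn.Module` (call `apply` on the class itself; autograd Functions should not be instantiated):

```python
# in __init__ (requires_grad=True is already the default for nn.Parameter)
self.z = nn.Parameter(torch.tensor(1.0))
# in forward, use the clamped value
z = Clamp.apply(self.z)
```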
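Why use a custom Function instead of plain `torch.clamp`? A likely reason (an assumption; the answer doesn't say) is that the built-in clamp has zero gradient wherever the input lies outside `[0, 1]`, so a parameter that drifts to 2 would stop receiving updates. The sketch below compares the two gradients. `Model`, the initial value 2.0, and the input 3.0 are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class Clamp(torch.autograd.Function):
    """Clamp to [0, 1] in the forward pass; identity gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: propagate the gradient as if no clamp were applied.
        return grad_output.clone()


class Model(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.z = nn.Parameter(torch.tensor(2.0))  # starts outside [0, 1]

    def forward(self, x):
        return x * Clamp.apply(self.z)


# Built-in clamp: the gradient is zero because 2.0 lies outside [0, 1].
z_builtin = torch.tensor(2.0, requires_grad=True)
z_builtin.clamp(min=0, max=1).backward()
print(z_builtin.grad)  # tensor(0.)

# Custom Clamp: the forward value is clamped to 1, but the gradient passes through.
model = Model()
out = model(torch.tensor(3.0))  # 3.0 * clamp(2.0) = 3.0 * 1.0
out.backward()
print(model.z.grad)  # tensor(3.)
```

Passing the gradient through a non-differentiable or saturating op unchanged like this is commonly called a straight-through estimator; it lets `z` keep moving even while its clamped value is pinned at a bound.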