Regarding clamped learnable parameter

The following code works for me:

import torch

class Clamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # clamp to [0, 1] in the forward pass
        return input.clamp(min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the incoming gradient through
        # unchanged, even where the input was clamped
        return grad_output.clone()
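
A quick sanity check (a minimal sketch; the tensor values are arbitrary) shows how this differs from a plain torch.clamp: the gradient is passed through even where the input lies outside [0, 1], so a saturated parameter can still move back into range:

x = torch.tensor([-0.5, 0.5, 1.5], requires_grad=True)
y = Clamp.apply(x)   # tensor([0.0000, 0.5000, 1.0000])
y.sum().backward()
print(x.grad)        # tensor([1., 1., 1.]); plain clamp would give 0 at -0.5 and 1.5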

Note that custom autograd Functions are not meant to be instantiated; call the static Clamp.apply directly. In your nn.Module:

self.z = nn.Parameter(torch.tensor(1.0))  # requires_grad=True is already the default
z = Clamp.apply(self.z)  # use the returned value; self.z itself is not modified in place
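
Putting it together, here is a minimal sketch (the module and its forward logic are just placeholders) where z stays learnable but only ever acts on the data through its clamped value:

import torch.nn as nn

class ScaledModel(nn.Module):
    def __init__(self):
        super().__init__()
        # stored unconstrained; the clamp is applied on use
        self.z = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        z = Clamp.apply(self.z)  # constrained to [0, 1], gradient still flows
        return z * x

model = ScaledModel()
model(torch.randn(4)).sum().backward()
print(model.z.grad)  # non-None: z receives a gradient even if it saturates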