Custom binarization layer with straight through estimator gives error

You can use torch.sign directly in the forward of your module, since its backward is already implemented:

class Net(nn.Module):
    def __init__(self):
        [...]

    def forward(self, input):
        x = self.layer1(input)
        # binary layer step:
        x = torch.sign(x)
        [...]
        return x
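
For reference, here is a minimal self-contained sketch of the same idea; the layer names (fc1, fc2) and sizes are assumptions for illustration only:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layers, just to make the example runnable
        self.fc1 = nn.Linear(16, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, input):
        x = self.fc1(input)
        # Binarization step: torch.sign has a backward registered,
        # so calling .backward() later does not raise an error
        x = torch.sign(x)
        x = self.fc2(x)
        return x

net = Net()
out = net(torch.randn(2, 16))
out.sum().backward()  # runs without error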