Binary Activation Function with PyTorch

I have a 2-layer fully connected network. I would like to convert the output of the first layer to binary. This means I would like a binary-step activation function in the forward pass and a ReLU activation function in the backward pass. How can I implement this? Any idea would be appreciated.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

net = Net()

This example could probably do what you need.

Thank you for the pointer. I rewrote my code according to the example:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x):
        x1 = F.relu(self.fc1(x))
        x_backward = x1
        x1[x1 <= 0] = 0
        x1[x1 > 0] = 1
        x_forward = x1
        y1 = x_backward + (x_forward - x_backward).detach()
        y2 = self.fc2(y1)
        y3 = F.log_softmax(y2, dim=1)
        return y3  

But it does not work: it still returns binary values during both the forward and backward passes.
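The reason is that x_backward = x1 only binds a second name to the same tensor, so the in-place assignments binarize x_backward as well. Computing the step out of place makes the trick behave as intended; a sketch, reusing the layer names from the code above:

    def forward(self, x):
        x_backward = F.relu(self.fc1(x))                      # ReLU output, kept for the gradient path
        x_forward = (x_backward > 0).float()                  # binary step, computed out of place
        y1 = x_backward + (x_forward - x_backward).detach()   # forward: binary, backward: ReLU
        y2 = self.fc2(y1)
        return F.log_softmax(y2, dim=1)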

I also tried to define a new class and define the forward and backward passes separately…

I tried to do it like this, but it gives me an error (AttributeError: 'Binary_AF' object has no attribute 'dim').

Do you think it makes sense to write the code like this? If yes, how can I take care of the error, and if not, I would appreciate any suggestions.

class Binary_AF:
    def __init__(self, x):
        self.x = x

    def forward(self):
        self.x[self.x <= 0] = 0
        self.x[self.x > 0] = 1
        return self.x

    def backward(self):
        return self.x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        y = Binary_AF(x)
        y = self.fc2(y)
        y = F.log_softmax(y, dim=1)
        return y
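The AttributeError occurs because Binary_AF is a plain Python class: Binary_AF(x) only constructs an object (its forward method is never called), and self.fc2 then receives that object rather than a tensor, which is where the missing 'dim' attribute comes from. The usual way to give an op its own forward and backward is to subclass torch.autograd.Function. A sketch of that idea (BinaryAF is just an illustrative name here; the STE module later in the thread does essentially the same thing):

class BinaryAF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # binary step in the forward pass
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: hand the incoming gradient back unchanged
        return grad_output

# inside Net.forward:
#     x = F.relu(self.fc1(x))
#     y = BinaryAF.apply(x)   # call with .apply, not by constructing an object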

Hi, I think there is a simpler way by combining sign and relu:

# test input
input = torch.tensor([-20.0, 11.0], requires_grad=True)

# binary
sign = torch.sign(input)
binary_out = torch.relu(sign)
print(binary_out) # 0., 1.

Theoretically, sign (i.e. the unit step) is not differentiable, but it can still be backpropagated in PyTorch.
I have no idea how they designed it…
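It runs because autograd treats sign as having zero gradient (the function is flat almost everywhere), so backpropagation succeeds but nothing useful flows back through it, which is why a straight-through estimator is still needed for training. A quick check, reusing the tensor from above:

x = torch.tensor([-20.0, 11.0], requires_grad=True)
torch.relu(torch.sign(x)).sum().backward()
print(x.grad)  # tensor([0., 0.])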

Here is a question about this:

For this, and for future visitors, I can recommend this blog post explaining the Straight-Through Estimator (Bengio et al.).

For a quick fix, this module (directly copied from the blog post linked above) should be able to handle a binary activation function.

import torch
import torch.nn as nn
import torch.nn.functional as F

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # forward pass: binary step (1 where the input is positive, 0 elsewhere)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # backward pass: pass the gradient straight through, clamped to [-1, 1] by hardtanh
        return F.hardtanh(grad_output)

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

x = torch.randn(30,30)
print(x)
estimator = StraightThroughEstimator()
print(estimator(x))
# tensor([[-1.6039e+00, -3.1636e-01, -1.4391e+00,  8.4508e-02, -2.6198e+00, ...
# tensor([[0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., ...
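
To plug this into the network from the top of the thread, the estimator can simply replace the manual thresholding from the earlier attempts. A sketch, reusing the Net definition and layer sizes from the question:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        self.binary = StraightThroughEstimator()
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x):
        x = self.binary(self.fc1(x))   # binary forward, straight-through backward
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)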