Model unable to converge after quantization of the linear weights

I'm trying to ternarize the linear layers of a simple convnet (three-level quantization of the weights with a scaling factor). The model converges to good accuracy on several datasets without ternarization, but fails to reach any useful accuracy once the weights are ternarized.

My implementation works like this: the weights are initialized with Xavier initialization, and the linear layers are ternarized before forward propagation. Because plain gradient descent suffers from vanishing gradients (the ternarization function is step-like, so its gradient is zero almost everywhere), I have also implemented gradient estimators such as the straight-through estimator (STE) and a ReLU-shaped estimator, using full backward hooks on nn.Module.
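
To make the intended gradient path concrete, here is a minimal sketch of the straight-through estimator written as a custom torch.autograd.Function (only an illustration of what I am aiming for; my actual implementation goes through full backward hooks instead):

import torch

class TernarySTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        # Ternarize: per-row threshold at 0.75 * mean(|W|), values in {-1, 0, +1}.
        delta = 0.75 * weight.abs().mean(dim=1, keepdim=True)
        mask = (weight.abs() >= delta).float()
        return mask * weight.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the incoming gradient through unchanged.
        return grad_output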

The code for ternarizing the weights is as follows:

import torch
import torch.nn as nn


class TernarizeOp:
    def __init__(self, model):
        self.model = model

        # Count the nn.Linear layers whose weights will be ternarized.
        count_targets = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        self.ternarize_range = list(range(count_targets))
        self.num_of_params = count_targets

        self.saved_params = []    # full-precision copies of the weights
        self.target_modules = []  # weight Parameters to be ternarized
        for m in model.modules():
            if isinstance(m, nn.Linear):
                self.saved_params.append(m.weight.data.clone())  # tensor copy
                self.target_modules.append(m.weight)             # Parameter


    def SaveWeights(self):
        # Store the current full-precision weights so they can be restored later.
        for index in range(self.num_of_params):
            self.saved_params[index].copy_(self.target_modules[index].data)

    def TernarizeWeights(self):
        # Replace every Linear weight with its ternarized version and collect
        # the per-row scaling factors.
        alpha = []
        for index in range(self.num_of_params):
            output, alpha_tmp = self.Ternarize(self.target_modules[index].data)
            self.target_modules[index].data = output
            alpha.append(alpha_tmp)
        return alpha

    def Ternarize(self, tensor):
        tensor = tensor.cuda()

        # Per-row threshold: delta_i = 0.75 * mean(|W_i|).
        new_tensor = tensor.abs()
        delta = 0.75 * torch.mean(new_tensor, dim=1)

        # Mask of entries whose magnitude reaches the row threshold
        # (transpose so that delta broadcasts along the rows).
        new_tensor = torch.t(new_tensor)
        t = torch.greater_equal(new_tensor, delta).type(torch.cuda.FloatTensor)
        t = torch.t(t)
        new_tensor = torch.t(new_tensor)

        # Sign of each weight: +1, -1 or 0.
        x = torch.greater(tensor, 0).type(torch.cuda.FloatTensor)
        y = torch.less(tensor, 0).type(torch.cuda.FloatTensor)
        z = torch.add(x, torch.mul(y, -1))

        # Ternary weights in {-1, 0, +1}.
        final = torch.mul(t, z)

        # Per-row scaling factor: mean over the row of (ternary weight * |weight|).
        alpha = torch.mean(torch.mul(final, new_tensor), dim=1)

        return final, alpha

    def Restore(self):
        # Copy the saved full-precision weights back into the model.
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])
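
For reference, the class is meant to be driven from the training loop roughly like this (optimizer, criterion and train_loader are placeholders here, not my exact code):

ternarize_op = TernarizeOp(model)

for inputs, targets in train_loader:
    inputs, targets = inputs.cuda(), targets.cuda()

    ternarize_op.SaveWeights()               # keep a full-precision copy
    alpha = ternarize_op.TernarizeWeights()  # model now holds ternary weights

    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()

    ternarize_op.Restore()                   # gradients update the full-precision weights
    optimizer.step()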

In theory the model should converge to near-optimal weights and good accuracy, but that does not happen.

What could be the reason?

Update: the gradients still seem to vanish and go to zero after a few epochs. I am using batch normalization and Xavier initialization for the weights, but it does not seem to help.
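
For what it's worth, this is roughly the check I run after loss.backward() to watch the gradient magnitudes go to zero (just a sketch):

for name, p in model.named_parameters():
    if p.grad is not None:
        print(name, p.grad.abs().mean().item())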