Loss stuck during quantization-aware training at 16 bits

Dear all,

I am training a very tiny transformer model in PyTorch. Training with the default settings works fine for me: both the training loss and the validation loss decrease gradually. But when I try to apply fake quantization to all of the linear layers, both losses stop decreasing and instead stay at a constant value from the very beginning.

May I ask what has gone wrong, please?

Below is the forward function of the model:

    def forward(self, x, targets=None, quantize=False, normalize=True):
        batch_size, sequence_length = x.shape
        # Get token and position embeddings
        tok_emb = self.token_embedding_table(x)
        pos_emb = self.position_embedding_table(torch.arange(sequence_length, device=DEVICE))
        x = tok_emb + pos_emb
        # apply normalization
        if normalize:
            x = self.ln1(x)
        # --- Enter Attention Block   ---
        # Create Q, K, V matrices with weightings
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        attention = Q @ K.transpose(-2, -1) / self.scale
        # Apply softmax
        attention_weights = F.softmax(attention, dim=-1)
        weighted_value = attention_weights @ V
        x = x + weighted_value
        # --- Exit Attention Block    ---
        # --- Enter Feedforward Block ---
        if normalize:
            x = self.ln2(x)
        x_ff_linear1 = self.feed_forward_Linear1(x)
        x_ff_relu = self.feed_forward_relu(x_ff_linear1)
        x_ff_linear2 = self.feed_forward_Linear2(x_ff_relu)
        x = x + x_ff_linear2
        # --- Exit Feedforward Block  ---
        if normalize:
            x = self.ln_f(x)
        logits = self.lm_head(x)
        # calculate loss
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape          # C == VOCAB_SIZE
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        # Return logits and loss
        return logits, loss
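
The training step itself is just a standard loop over this forward pass, roughly like this (a simplified sketch: the optimizer, the constants, and the get_batch helper here are illustrative, not my exact code):

    # simplified training step (my real loop also evaluates the validation
    # loss periodically with the same forward call)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    for step in range(MAX_ITERS):
        xb, yb = get_batch("train")   # token-id tensors of shape (batch_size, sequence_length)
        logits, loss = model(xb, targets=yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()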

And here is how I try to quantize the model:

        # quantization for linear layers
        self.quant_16_linear = quant.QConfig(
            activation=quant.FakeQuantize.with_args(
                observer=quant.MinMaxObserver,
                quant_min=-32768,
                quant_max=32767,
                dtype=torch.qint32,
                qscheme=torch.per_tensor_affine,
            ),
            weight=quant.FakeQuantize.with_args(
                observer=quant.MinMaxObserver,
                quant_min=-32768,
                quant_max=32767,
                dtype=torch.qint32,
                qscheme=torch.per_tensor_affine,
            ),
        )
        if quantize:
            self.query.qconfig = self.quant_16_linear
            self.key.qconfig = self.quant_16_linear
            self.value.qconfig = self.quant_16_linear
            self.feed_forward_Linear1.qconfig = self.quant_16_linear
            self.feed_forward_Linear2.qconfig = self.quant_16_linear
            self.lm_head.qconfig = self.quant_16_linear
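
In case it matters, my understanding is that these qconfigs only take effect once the model is prepared for QAT, so that FakeQuantize modules actually get attached to the flagged layers. Roughly like this (a simplified sketch: I am assuming quant is torch.ao.quantization, and the model class name here is just a placeholder):

    import torch.ao.quantization as quant

    model = TinyTransformer(quantize=True)   # placeholder class name; __init__ assigns the qconfigs shown above
    model.train()
    quant.prepare_qat(model, inplace=True)   # attaches FakeQuantize modules to every layer that has a qconfig
    # ... then train exactly as in the unquantized case ...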

I think it might be hard for us to give advice from this alone. Maybe try increasing the learning rate, or tuning some other hyperparameters?
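
For example, a quick learning-rate sweep is cheap to run and would at least tell you whether the plateau moves at all (an illustrative sketch; the class name and the values are placeholders):

    # quick sanity sweep over learning rates (illustrative values)
    for lr in (1e-4, 3e-4, 1e-3, 3e-3):
        model = TinyTransformer(quantize=True)   # placeholder class name
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        # ... run a few hundred training steps and compare the loss curves ...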