Dear all,
I am training a very tiny transformer model in PyTorch. Training with the default settings works fine for me: both the training loss and the validation loss decrease gradually. But when I try to apply fake quantization to all of the linear layers, both losses stop decreasing and just stay at a constant value from the beginning.
May I ask what has gone wrong, please?
Below is the forward function of the model:
def forward(self, x, targets=None, quantize=False, normalize=True):
    batch_size, sequence_length = x.shape

    # Get token and position embeddings
    tok_emb = self.token_embedding_table(x)
    pos_emb = self.position_embedding_table(torch.arange(sequence_length, device=DEVICE))
    x = tok_emb + pos_emb

    # apply normalization
    if normalize:
        x = self.ln1(x)

    # --- Enter Attention Block ---
    # Create Q, K, V matrices with weightings
    Q = self.query(x)
    K = self.key(x)
    V = self.value(x)
    attention = Q @ K.transpose(-2, -1) / self.scale

    # Apply softmax
    attention_weights = F.softmax(attention, dim=-1)
    weighted_value = attention_weights @ V
    x = x + weighted_value
    # --- Exit Attention Block ---

    # --- Enter Feedforward Block ---
    if normalize:
        x = self.ln2(x)
    x_ff_linear1 = self.feed_forward_Linear1(x)
    x_ff_relu = self.feed_forward_relu(x_ff_linear1)
    x_ff_linear2 = self.feed_forward_Linear2(x_ff_relu)
    x = x + x_ff_linear2
    # --- Exit Feedforward Block ---

    if normalize:
        x = self.ln_f(x)
    logits = self.lm_head(x)

    # calculate loss
    if targets is None:
        loss = None
    else:
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

    # Return logits and loss
    return logits, loss
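For completeness, the training step is essentially the standard pattern sketched below; this is only a minimal sketch, and model, get_batch, MAX_STEPS and the AdamW settings are placeholders rather than my exact script:

import torch

# minimal sketch of how the forward pass above is called during training
# (model, get_batch, MAX_STEPS and the optimizer settings are placeholders)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for step in range(MAX_STEPS):
    x, y = get_batch("train")                                       # (B, T) index tensors
    logits, loss = model(x, targets=y, quantize=True, normalize=True)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()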
And here is how I try to quantize the model:
# quantization config for the linear layers
# (quant here is assumed to be torch.ao.quantization, or the older torch.quantization alias)
self.quant_16_linear = quant.QConfig(
    activation=quant.FakeQuantize.with_args(
        observer=quant.MinMaxObserver,
        quant_min=-32768,
        quant_max=32767,
        dtype=torch.qint32,
        qscheme=torch.per_tensor_affine
    ),
    weight=quant.FakeQuantize.with_args(
        observer=quant.MinMaxObserver,
        quant_min=-32768,
        quant_max=32767,
        dtype=torch.qint32,
        qscheme=torch.per_tensor_affine
    )
)

if quantize:
    self.query.qconfig = self.quant_16_linear
    self.key.qconfig = self.quant_16_linear
    self.value.qconfig = self.quant_16_linear
    self.feed_forward_Linear1.qconfig = self.quant_16_linear
    self.feed_forward_Linear2.qconfig = self.quant_16_linear
    self.lm_head.qconfig = self.quant_16_linear
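For reference, my understanding from the torch.ao.quantization docs is that assigning .qconfig is only the first step of the eager-mode flow, and that prepare_qat() is what actually inserts the FakeQuantize modules. A minimal sketch of that flow, as I understand it, is below (TinyTransformer is just a placeholder name for my model class); I am not sure whether this is related to my problem, which is part of what I am asking.

import torch
import torch.ao.quantization as quant

# minimal sketch of the eager-mode QAT flow as I understand it
# (TinyTransformer is a placeholder for my model class)
model = TinyTransformer()
model.train()

# per-layer qconfig assignment, same as in the snippet above
for layer in (model.query, model.key, model.value,
              model.feed_forward_Linear1, model.feed_forward_Linear2,
              model.lm_head):
    layer.qconfig = model.quant_16_linear

# prepare_qat() swaps the configured nn.Linear modules for their QAT
# counterparts, which carry the FakeQuantize modules used during training
quant.prepare_qat(model, inplace=True)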