I firsrt use the Trochao to quantize my model, like that:
import torch
import random
import numpy as np
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
setup_seed(20)
class SimpleModule(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(input_dim, 1024),
nn.Linear(1024, output_dim)
)
def forward(self, x):
return self.fc(x)
simple_model = SimpleModule(32, 32)
simple_model = simple_model.to("cuda")
print(simple_model)
# qat_config = QATConfig(
# activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False,),
# weight_config=IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True,),
# step="prepare",
# )
base_config = Int4WeightOnlyConfig(group_size=32)
qat_config = QATConfig(base_config, step= 'prepare')
quantize_(simple_model, qat_config)
print(simple_model)
# simple_model.compile()
optimizer = torch.optim.AdamW(simple_model.parameters(), lr=1e-4)
inputs = torch.randn(32, 32, device='cuda')
target = torch.randn(32, 32, device='cuda')
for i in range(1):
optimizer.zero_grad()
output = simple_model(inputs)
loss = nn.functional.mse_loss(output, target)
loss.backward()
optimizer.step()
convert_qat_config = QATConfig(
# activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False,),
# weight_config=IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True,),
# base_config,
step="convert",
)
quantize_(simple_model, convert_qat_config)
print(simple_model)
but get some error:
RuntimeError: shape '[1024, -1, 128]' is invalid for input of size 32768
and then try to print the quantize config get:
# the quantize prepar model
FakeQuantizer(Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16))
# the quantizeconfig
Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16)
i curious that i have set the quantize config group_size=32 but why print model the group_size=128