torchao QAT raises RuntimeError with Int4WeightOnlyConfig(group_size=32)

I first used torchao to quantize my model, like this:

import torch
import random
import numpy as np
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig

def setup_seed(seed):
    """Seed the RNGs used by this script and set the cuDNN deterministic flag.

    Args:
        seed: Integer passed unchanged to the torch (CPU and all CUDA
            devices), NumPy and Python ``random`` seeding functions.
    """
    # Same order as before: torch CPU, torch CUDA, NumPy, then stdlib random.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(20)

class SimpleModule(nn.Module):
    """Two stacked linear layers: input_dim -> 1024 -> output_dim.

    There is no activation between the layers; both live in ``self.fc``,
    an ``nn.Sequential`` indexed 0 and 1.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Create the layers in forward order so parameter initialisation
        # draws from the RNG in the same sequence.
        hidden_layer = nn.Linear(input_dim, 1024)
        output_layer = nn.Linear(1024, output_dim)
        self.fc = nn.Sequential(hidden_layer, output_layer)

    def forward(self, x):
        """Run ``x`` through both linear layers and return the result."""
        projected = self.fc(x)
        return projected

# Repro script: build a small model, apply the QAT "prepare" step, run one
# training step, then apply the QAT "convert" step.

# Build the 32 -> 1024 -> 32 model and move it to the GPU.
simple_model = SimpleModule(32, 32)
simple_model = simple_model.to("cuda")


# Print the model before any quantization config is applied.
print(simple_model)

# Alternative prepare config (explicit activation/weight fake-quantize
# configs), currently disabled:
# qat_config = QATConfig(
#     activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False,),
#     weight_config=IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True,),
#     step="prepare",
# )
# Prepare step: wrap an int4 weight-only base config that requests
# group_size=32 in a QATConfig and apply it to the model.
base_config = Int4WeightOnlyConfig(group_size=32)
qat_config = QATConfig(base_config, step= 'prepare')
quantize_(simple_model, qat_config)

# Print the model after the prepare step.
print(simple_model)

# simple_model.compile()

# Single training step on random data (batch of 32, 32 features).
optimizer = torch.optim.AdamW(simple_model.parameters(), lr=1e-4)
inputs = torch.randn(32, 32, device='cuda')
target = torch.randn(32, 32, device='cuda')
for i in range(1):
    optimizer.zero_grad()
    # NOTE(review): the first Linear's weight is (1024, 32) = 32768 elements,
    # matching the size in the reported RuntimeError, and 32 is not divisible
    # by the group size of 128 in that message. Presumably the failure is
    # raised from this forward pass - confirm against the traceback.
    output = simple_model(inputs)
    loss = nn.functional.mse_loss(output, target)

    loss.backward()
    optimizer.step()

# Convert step: only `step` is passed; every other argument is commented out.
# NOTE(review): base_config is not passed here - confirm whether convert is
# meant to apply Int4WeightOnlyConfig or only undo the prepare step.
convert_qat_config = QATConfig(
    # activation_config=IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False,),
    # weight_config=IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True,),
    # base_config,
    step="convert",
)

quantize_(simple_model, convert_qat_config)
# Print the model after the convert step.
print(simple_model)

but I get the following error:

RuntimeError: shape '[1024, -1, 128]' is invalid for input of size 32768

I then tried printing the quantization config and got:

# the fake quantizer in the prepared model
FakeQuantizer(Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16))
# the quantization config
Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16)

I'm curious: I set group_size=32 in the quantization config, so why does the printed model show group_size=128?