Loading quantized model on clients in federated learning setup

Hello! I would like to perform federated learning with quantization of the weights. I am running into difficulties with loading the model back onto the “clients” after quantization, though. Could you point out what I am doing wrong, and whether what I want to achieve is possible with the current PyTorch implementation? When I run the code below:

import torch
from torch import nn

from src.data_utils import get_data_loaders, get_model_bits
from src.training import *

BATCH_SIZE = 30
NUM_CLIENTS = 1
ACCURACY_THRESHOLD = 93
IID_SPLIT = False

# Load data
train_loaders, _, test_loader = get_data_loaders(BATCH_SIZE, NUM_CLIENTS, percentage_val=0, iid_split=IID_SPLIT)

# Initialize all clients
clients = [Client(train_loader) for train_loader in train_loaders]

# Set seed for the script
torch.manual_seed(clients[0].seed)

testing_accuracy = 0
bits_transferred = 0
num_rounds = 0
bits_conserved = 0

centralServer = Client(test_loader)

while testing_accuracy < ACCURACY_THRESHOLD:
    num_rounds += 1
    print("Communication Round {0}".format(num_rounds))

    if num_rounds > 0:
        # Load server weights onto clients
        for client in clients:
            with torch.no_grad():
                # Calculate number of bits in full server model
                float_model_bits = get_model_bits(centralServer.model)
                # Quantize server's model
                centralServer.model.qconfig = torch.quantization.default_qconfig
                quantized_model = torch.quantization.prepare(centralServer.model)
                quantized_model = torch.quantization.convert(quantized_model)
                bits_transferred = get_model_bits(quantized_model)
                # Calculate how many bits we saved
                bits_conserved = float_model_bits - bits_transferred
                qm = quantized_model.state_dict()
                to_load = {}
                to_keep = ["conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias"]
                for name, param in qm.items():
                    if "_packed_params" in name:
                        # Rename e.g. "fc1._packed_params.weight" -> "fc1.weight"
                        parts = name.split('.')
                        key = parts[0] + '.' + parts[2]
                        to_load[key] = param
                    elif name in to_keep:
                        to_load[name] = param
                # Distribute quantized model on clients
                client.model.load_state_dict(to_load, strict=False)
    # Perform E local training steps for each client
    for client_idx, client in enumerate(clients):
        print("Training client {0}".format(client_idx))
        for epoch in range(1, client.epochs + 1):
            train(client, epoch)

I get the following error:

RuntimeError: Error(s) in loading state_dict for CNN:
	While copying the parameter named "conv1.weight", whose dimensions in the model are torch.Size([32, 1, 5, 5]) and whose dimensions in the checkpoint are torch.Size([32, 1, 5, 5]).
	While copying the parameter named "conv2.weight", whose dimensions in the model are torch.Size([64, 32, 5, 5]) and whose dimensions in the checkpoint are torch.Size([64, 32, 5, 5]).
	While copying the parameter named "fc1.weight", whose dimensions in the model are torch.Size([512, 1024]) and whose dimensions in the checkpoint are torch.Size([512, 1024]).
	While copying the parameter named "fc2.weight", whose dimensions in the model are torch.Size([10, 512]) and whose dimensions in the checkpoint are torch.Size([10, 512]).

As you can see, I try to modify the quantized model’s state_dict: it has keys like “fc1._packed_params.weight”, which I assume hold the quantized weights, so I rename them to match the key names of the original model. Any help on how to work around this, or confirmation that it is simply not possible at the moment, would be very much appreciated. I think it is the same as, or related to, the problem described here: Can't load model after dynamic quantization, but as far as I can tell no answer was found there. Thank you in advance for your answers!
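In case it helps, below is a minimal, self-contained sketch of the direction I have been considering as a workaround: instead of renaming the _packed_params keys, dequantize each quantized module’s weight back to a float tensor and load those onto the client. TinyNet here is just a hypothetical stand-in for my actual CNN, and I am assuming dynamic quantization of the Linear layers (conv layers stay in float under dynamic quantization, so they could be copied over directly). I am not sure this is the intended way to do it:

import torch
import torch.nn.quantized.dynamic as nnqd
from torch import nn

# Hypothetical stand-in for my actual CNN
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

server_model = TinyNet()

# Dynamic quantization packs each Linear's weight into "<name>._packed_params" entries
quantized_model = torch.quantization.quantize_dynamic(
    server_model, {nn.Linear}, dtype=torch.qint8
)

# Rebuild a float state_dict by dequantizing each quantized module's weight
to_load = {}
for name, module in quantized_model.named_modules():
    if isinstance(module, nnqd.Linear):
        to_load[name + ".weight"] = module.weight().dequantize()
        to_load[name + ".bias"] = module.bias()

client_model = TinyNet()
client_model.load_state_dict(to_load, strict=False)

If that is a reasonable direction, I suppose I would still count bits_transferred on the quantized model before dequantizing on the client side, since the float tensors are only reconstructed after transfer.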