How to copy my model parameters to another model

Hi
I have a federated learning scenario in which I want to send my cloud model's parameters to different clients. I have tried a few different ways. First I did it with

model_dict[name_of_models[i]].conv1.weight.data = main_model.conv1.weight.detach().clone()

and it was working, but as I saw here it's better not to use .data in my code, so I changed it to

model_dict[name_of_models[i]].conv1.weight = nn.Parameter(main_model.conv1.weight.detach().clone())

but when I do this my client models stop updating. I think that's because I replaced the Parameter objects that their optimizers still reference.
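If I understand correctly, the optimizer keeps references to the old Parameter objects, so the newly assigned ones are never updated. A toy sketch of what I mean (just an illustration with a made-up layer and SGD optimizer, not my real code):

    import torch
    import torch.nn as nn

    client = nn.Linear(4, 2)
    opt = torch.optim.SGD(client.parameters(), lr=0.1)

    # replacing the Parameter object breaks the link to the optimizer
    client.weight = nn.Parameter(torch.zeros(2, 4))

    # the optimizer still holds the *old* weight Parameter, so the new
    # client.weight is never touched by opt.step()
    print(any(p is client.weight for g in opt.param_groups for p in g['params']))  # False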
Now I'm doing it with

with torch.no_grad():
    for i in range(number_of_clients):
        state_dict = model_dict[name_of_models[i]].state_dict()

        state_dict['conv1.weight'] = main_model.conv1.weight.detach().clone()
        state_dict['conv2.weight'] = main_model.conv2.weight.detach().clone()

        state_dict['conv1.bias'] = main_model.conv1.bias.detach().clone()
        state_dict['conv2.bias'] = main_model.conv2.bias.detach().clone()

        state_dict['fc1.weight'] = main_model.fc1.weight.detach().clone()
        state_dict['fc2.weight'] = main_model.fc2.weight.detach().clone()
        state_dict['fc3.weight'] = main_model.fc3.weight.detach().clone()

        state_dict['fc1.bias'] = main_model.fc1.bias.detach().clone()
        state_dict['fc2.bias'] = main_model.fc2.bias.detach().clone()
        state_dict['fc3.bias'] = main_model.fc3.bias.detach().clone()

        model_dict[name_of_models[i]] = model_dict[name_of_models[i]].load_state_dict(state_dict)

but when I do this I get the error ‘_IncompatibleKeys’ object has no attribute ‘train’ while training my model.
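If it helps, I suspect the problem is the last line above: as far as I know, load_state_dict loads the weights in place and only returns a report of missing/unexpected keys (the _IncompatibleKeys named tuple), so I'm overwriting my model with that return value. I guess it should just be called without the assignment (a sketch, keeping my placeholder names):

    # load_state_dict updates the module in place; don't reassign its return value
    model_dict[name_of_models[i]].load_state_dict(state_dict)

but I'm still not sure whether the state_dict route is the right approach overall.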
I would appreciate any advice on how to do this properly.
Thanks

Since you mentioned federated learning, shouldn't the data transfer happen in a distributed environment? How about the following code snippet?

    # Flattening all the parameters of the cloud model into a contiguous buffer to prepare for data transfer.
    flat_params = torch.cat([p.data.view(-1) for p in model.parameters()])

    # broadcast the tensors or call process group send/recv?
    ...
    
    # Copy the parameters to the client model layer by layer.
    offset = 0
    for p in module.parameters():
        p.data = flat_params[offset : offset + p.numel()].view_as(p)
        offset += p.numel()

Thanks for replying.
First of all, I thought it's not recommended to use .data in our code.
Second, sorry, I didn't understand what that offset is supposed to do. Could you explain more?
In my scenario I have 10 clients (nodes), and each of them has its own model with the same architecture. My code works when I update the client models layer by layer with the code below

model_dict[name_of_models[i]].conv1.weight.data = main_model.conv1.weight.detach().clone()

As I've heard that the .data attribute can cause silent errors in my code, I'm looking for an alternative way to do it.

It's used for unpacking a flattened tensor back into a set of per-layer tensors, one slice at a time.
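For example, with two parameters holding 6 and 3 elements, the loop walks through the flat buffer like this (a toy sketch, not tied to your model):

    import torch

    flat = torch.arange(9.)                            # pretend this came from torch.cat([...])
    shapes = [torch.Size([2, 3]), torch.Size([3])]     # per-layer shapes, 6 + 3 = 9 elements

    offset = 0
    for shape in shapes:
        chunk = flat[offset : offset + shape.numel()].view(shape)  # slice out this layer
        offset += shape.numel()                                    # jump past it
        print(chunk.shape)   # torch.Size([2, 3]) then torch.Size([3])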

As I've heard that the .data attribute can cause silent errors in my code, I'm looking for an alternative way to do it.

Hmm, I'm not sure why it would cause silent errors in a distributed environment.

To set a Tensor's storage, the recommended way is to use the Tensor.set_() function: torch.Tensor.set_ — PyTorch 1.9.0 documentation
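For the use case in this thread, a minimal sketch might look like this (assuming both models share the same architecture; client_model and main_model are placeholder names, not from the code above):

    import torch

    with torch.no_grad():
        for client_p, main_p in zip(client_model.parameters(), main_model.parameters()):
            # point the client Parameter at a fresh copy of the cloud weights without
            # replacing the Parameter object itself, so optimizer references stay valid
            client_p.set_(main_p.detach().clone())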

Learned from @albanD 🙂


You can refer to the following example.

You can convert the model weights to a single Tensor with the parameters_to_vector function, communicate it between ranks, and then turn it back into weights with the vector_to_parameters function.

I'm using a collective communication function to synchronize the weight parameters across all GPUs, but you can change that to point-to-point communication functions such as torch.distributed.send and torch.distributed.recv.

docs: torch.distributed

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = ...

# synchronize model parameters across nodes
vector = parameters_to_vector(model.parameters())

dist.broadcast(vector, 0)   # broadcast parameters to other processes
if dist.get_rank() != 0:
    vector_to_parameters(vector, model.parameters())
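
If you'd rather use the point-to-point style mentioned above, a rough sketch could look like this (assuming the process group is already initialized, rank 0 holds the cloud model, and every other rank is a client):

    import torch.distributed as dist
    from torch.nn.utils import parameters_to_vector, vector_to_parameters

    # flatten the weights into one tensor on every rank
    vector = parameters_to_vector(model.parameters())

    if dist.get_rank() == 0:
        # cloud/server: send the weights to each client rank
        for client_rank in range(1, dist.get_world_size()):
            dist.send(vector, dst=client_rank)
    else:
        # client: receive the weights and write them back into the model
        dist.recv(vector, src=0)
        vector_to_parameters(vector, model.parameters())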