How to decompose a deep learning model's trainable parameters into two parts in PyTorch?

I ran into this question while reading the paper ‘Federated Semi-Supervised Learning with Inter-Client Consistency & Disjoint Learning’, and I am curious about how the disjoint learning part works. The authors say that a given model's parameters are decomposed into two parts, model = a + b, where each parameter of the model equals the sum of the parameter at the same position in a and the one at the same position in b. When we train the model on dataset A, we only update a and hold b constant; when we train it on dataset B, we only update b and hold a constant.

How can we do this in PyTorch?
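One way to sketch this, assuming a recent PyTorch that ships `torch.nn.utils.parametrize`, is to re-express every parameter of an existing model as a + b: the original parameter becomes a, a new tensor b of the same shape is added on top, and the forward pass always sees a + b. Training on dataset A then freezes b and steps only an optimizer over a; training on dataset B does the reverse. The helper names `AddDelta`, `decompose`, `set_phase`, `snapshot` and `train_step`, the zero initialisation of b, the toy model, and the random tensors standing in for datasets A and B are all illustrative assumptions, not the paper's code.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class AddDelta(nn.Module):
    """Hypothetical parametrization: the stored tensor `a` is seen as `a + b`."""

    def __init__(self, like: torch.Tensor):
        super().__init__()
        # b has the same shape as a; zero init (an assumption) keeps the
        # model's output unchanged right after the split.
        self.b = nn.Parameter(torch.zeros_like(like.detach()))

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        return a + self.b


def decompose(model: nn.Module):
    """Split every parameter of `model` into a + b; return (a_params, b_params)."""
    # Snapshot (module, name) pairs first, because registering a
    # parametrization adds new submodules and parameters to the model.
    targets = [(m, name) for m in model.modules()
               for name, _ in m.named_parameters(recurse=False)]
    a_params, b_params = [], []
    for module, name in targets:
        delta = AddDelta(getattr(module, name))
        parametrize.register_parametrization(module, name, delta)
        # The original tensor is kept as `.original`; that is our `a`.
        a_params.append(module.parametrizations[name].original)
        b_params.append(delta.b)
    return a_params, b_params


def set_phase(a_params, b_params, train_a: bool):
    """Freeze one half so no gradients are computed for it."""
    for p in a_params:
        p.requires_grad_(train_a)
    for p in b_params:
        p.requires_grad_(not train_a)


def snapshot(params):
    return [p.detach().clone() for p in params]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
a_params, b_params = decompose(model)

# One optimizer per half, so optimizer state never touches the frozen half.
opt_a = torch.optim.SGD(a_params, lr=0.1)
opt_b = torch.optim.SGD(b_params, lr=0.1)
loss_fn = nn.MSELoss()

# Random stand-ins for dataset A and dataset B.
x_A, y_A = torch.randn(16, 4), torch.randn(16, 2)
x_B, y_B = torch.randn(16, 4), torch.randn(16, 2)


def train_step(x, y, opt, train_a: bool) -> float:
    set_phase(a_params, b_params, train_a)
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
    return loss.item()


a_init, b_init = snapshot(a_params), snapshot(b_params)
train_step(x_A, y_A, opt_a, train_a=True)    # dataset A: update a, hold b
a_after_A, b_after_A = snapshot(a_params), snapshot(b_params)
train_step(x_B, y_B, opt_b, train_a=False)   # dataset B: update b, hold a
a_after_B, b_after_B = snapshot(a_params), snapshot(b_params)
```

Toggling `requires_grad` means no gradients are computed for the frozen half, and keeping one optimizer per half means optimizer state such as momentum only ever touches the half being trained. Only `nn.Parameter` tensors are split this way; buffers such as BatchNorm running statistics stay shared between the two phases.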

Thanks so much for your help!