Creating a modifiable MLP of MLPs

Hello,

I’ve created a custom nn module, class BatchMLP(nn.Module), that takes in a list of MLPs, all with the same architecture but different parameters, and combines them into a single MLP so that inference over all of them is vectorized. The first layer of a BatchMLP object holds the first layers of every MLP in the list, the second layer holds all the second layers, and so on.
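A minimal sketch of what such a BatchMLP might look like, assuming each MLP is an nn.Sequential of nn.Linear layers separated by nn.ReLU, and that the input has shape (num_models, batch, in_features) so that model i only sees x[i]. The k-th layer is stored as one Parameter of shape (num_models, out, in) and applied to all models at once with torch.baddbmm. The weights/biases ParameterList layout is an illustrative assumption, not the poster's actual implementation.

```python
from typing import List

import torch
import torch.nn as nn


class BatchMLP(nn.Module):
    """Hypothetical sketch: fuse N identically shaped MLPs into one batched MLP.

    Assumes every MLP is an nn.Sequential alternating nn.Linear and nn.ReLU.
    """

    def __init__(self, mlps: List[nn.Sequential]):
        super().__init__()
        per_model = [[m for m in mlp if isinstance(m, nn.Linear)] for mlp in mlps]
        n_layers = len(per_model[0])
        # Layer k stacks the k-th Linear of every MLP: weight (N, out, in), bias (N, 1, out).
        # torch.stack allocates new memory, so these Parameters are copies.
        self.weights = nn.ParameterList([
            nn.Parameter(torch.stack([layers[k].weight.detach() for layers in per_model]))
            for k in range(n_layers)
        ])
        self.biases = nn.ParameterList([
            nn.Parameter(torch.stack([layers[k].bias.detach() for layers in per_model]).unsqueeze(1))
            for k in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, batch, in_features); model i only sees x[i].
        last = len(self.weights) - 1
        for k, (w, b) in enumerate(zip(self.weights, self.biases)):
            # b + x @ w^T for every model at once -> (N, batch, out)
            x = torch.baddbmm(b, x, w.transpose(1, 2))
            if k < last:
                x = torch.relu(x)
        return x


torch.manual_seed(0)
mlps = [nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)) for _ in range(3)]
batch_mlp = BatchMLP(mlps)
x = torch.randn(3, 5, 4)
out = batch_mlp(x)
# Reference: run each MLP separately on its own slice of the input.
expected = torch.stack([mlp(x[i]) for i, mlp in enumerate(mlps)])
```

Because torch.stack allocates new memory, the stacked parameters in this sketch are copies; changing the original MLPs afterwards does not affect the batched module.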

During training, I want to be able to replace some of the MLPs inside this BatchMLP with new MLPs. For example, if a BatchMLP object is made out of 100 MLPs and I want to replace the ith one, I would need to do batch_mlp_object.layers[0].weight[index_i] = new_mlp.layers[0].weight, batch_mlp_object.layers[1].weight[index_i] = new_mlp.layers[1].weight, and so on for every layer.
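A sketch of that per-slot replacement, assuming (hypothetically) that the batched weights and biases are kept in nn.ParameterLists with shapes (N, out, in) and (N, out). Writing through w[i].copy_(...) under torch.no_grad() updates the existing Parameter in place, so an optimizer that already holds it keeps pointing at the right tensor. make_mlp, linears and replace_model are illustrative helpers, not an existing API.

```python
from typing import List

import torch
import torch.nn as nn


def make_mlp() -> nn.Sequential:
    # Hypothetical architecture shared by every model in the batch.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def linears(mlp: nn.Sequential) -> List[nn.Linear]:
    return [m for m in mlp if isinstance(m, nn.Linear)]


# Hypothetical batched storage for 100 models:
# weight k has shape (100, out, in), bias k has shape (100, out).
mlps = [make_mlp() for _ in range(100)]
weights = nn.ParameterList([
    nn.Parameter(torch.stack([linears(m)[k].weight.detach() for m in mlps])) for k in range(2)
])
biases = nn.ParameterList([
    nn.Parameter(torch.stack([linears(m)[k].bias.detach() for m in mlps])) for k in range(2)
])


@torch.no_grad()
def replace_model(weights: nn.ParameterList, biases: nn.ParameterList,
                  i: int, new_mlp: nn.Sequential) -> None:
    """Overwrite slot i of every stacked layer with new_mlp's parameters, in place."""
    for w, b, layer in zip(weights, biases, linears(new_mlp)):
        w[i].copy_(layer.weight)  # w[i] is a view, so this writes into w itself
        b[i].copy_(layer.bias)


new_mlp = make_mlp()
replace_model(weights, biases, 7, new_mlp)
```

This still copies the new MLP's values; it only avoids reallocating the batched tensors. Optimizer state for that slot (for example Adam's running moment estimates) is not reset by the copy.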

I would like to do this without copying the parameters from the new MLP into the BatchMLP object, if possible. For example, if I stored all the MLPs inside this object as self.mlps: List[nn.Module] = list_of_mlps and used them to construct an nn.Sequential for the BatchMLP object, then assigning self.mlps[index_i] = new_mlp would ideally update the sequential layers automatically, because both would point to the same tensors in memory.
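A quick check of the memory-sharing assumption, using plain nn.Linear layers as stand-ins for the MLPs (an illustrative simplification): torch.stack returns a new tensor, so the stacked weights do not share storage with the original modules, and reassigning a list entry only rebinds that entry. Indexing the stacked tensor, on the other hand, returns a view into its storage.

```python
import torch
import torch.nn as nn

# Stand-ins for the per-model MLPs (illustrative assumption).
mlps = [nn.Linear(4, 8) for _ in range(3)]
stacked = torch.stack([m.weight.detach() for m in mlps])  # new (3, 8, 4) tensor

# torch.stack copies: the stacked storage is separate from each module's weight.
shares_memory = mlps[0].weight.data_ptr() == stacked[0].data_ptr()

# Rebinding a list entry only changes what the list points to.
old_row = stacked[1].clone()
mlps[1] = nn.Linear(4, 8)
unchanged_after_rebind = torch.equal(stacked[1], old_row)

# A slice of the stacked tensor is a view, so an in-place write
# through it is visible in the stacked tensor itself.
view = stacked[2]
view.zero_()
visible_through_view = bool((stacked[2] == 0).all())
```

This suggests that, if sharing is possible at all, it would have to go the other way: the batched tensor owns the memory and each per-model weight is a view into it.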

Is it possible to do this, and if so, how?

Thanks!