I’ve created a custom nn module, `class BatchMLP(nn.Module):`, that takes in a list of MLPs, all of which have the same architecture but different parameters, and constructs a single new MLP out of them for vectorized inference. So the first layer of this BatchMLP object is all the first layers of each MLP in the list, the second layer of the BatchMLP is all the second layers of each MLP in the list, etc.
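To make the setup concrete, here is a minimal sketch of the kind of BatchMLP I mean. It is simplified (it only handles Linear/ReLU stacks), and all the names and dimensions are placeholders, not my actual code:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """A plain MLP; every MLP passed to BatchMLP shares this architecture."""
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.layers(x)

class BatchMLP(nn.Module):
    """Stacks corresponding layers of several same-shaped MLPs."""
    def __init__(self, mlps):
        super().__init__()
        # Collect the k-th Linear of every MLP, then stack along a new
        # leading dimension: weights[k] has shape (num_mlps, out_k, in_k).
        per_mlp = [[m for m in mlp.layers if isinstance(m, nn.Linear)]
                   for mlp in mlps]
        num_layers = len(per_mlp[0])
        self.weights = nn.ParameterList(
            nn.Parameter(torch.stack([lins[k].weight.detach() for lins in per_mlp]))
            for k in range(num_layers)
        )
        self.biases = nn.ParameterList(
            nn.Parameter(torch.stack([lins[k].bias.detach() for lins in per_mlp]))
            for k in range(num_layers)
        )

    def forward(self, x):
        # x: (num_mlps, batch, in_dim); one batched matmul per stacked layer.
        for k, (w, b) in enumerate(zip(self.weights, self.biases)):
            x = torch.baddbmm(b.unsqueeze(1), x, w.transpose(1, 2))
            if k < len(self.weights) - 1:
                x = torch.relu(x)
        return x
```

Note that torch.stack already copies the per-MLP tensors into the stacked parameters at construction time; the question below is about avoiding further copies when replacing individual MLPs later.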
During training, I want to be able to update some of the layers of this BatchMLP with new MLPs. So for example, if I have a BatchMLP object made out of 100 MLPs and I want to replace the i-th one, then I would need to do
batch_mlp_object.layers[0].weight[index_i] = new_mlp.layers[0].weight, batch_mlp_object.layers[1].weight[index_i] = new_mlp.layers[1].weight, ... for every layer.
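Concretely, the per-layer replacement I have in mind looks something like this, for one stacked layer (the shapes and names here are hypothetical placeholders):

```python
import torch
import torch.nn as nn

# One stacked layer of the BatchMLP: (num_mlps, out_dim, in_dim).
batched_weight = nn.Parameter(torch.randn(100, 8, 4))
# The corresponding layer of the replacement MLP: (out_dim, in_dim).
new_weight = torch.randn(8, 4)

index_i = 42
# In-place slice assignment on a leaf Parameter has to happen under
# no_grad; note that this still copies new_weight's values into the slice,
# which is exactly what I'd like to avoid.
with torch.no_grad():
    batched_weight[index_i] = new_weight
```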
I would like to try to do this without copying the parameters from the new MLP into the BatchMLP object, if possible. So for example, if I store all the MLPs inside this object as
self.mlps: List[nn.Module] = list_of_mlps, and I use this list to construct an
nn.Sequential for the BatchMLP object, then updating
self.mlps[index_i] = new_mlp would (I’m hoping) automatically update the sequential layers, because they would both point to the same tensors in memory.
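For reference, this is the kind of aliasing I mean, and the sanity check I'd use to verify it (a toy example, not my actual module): slicing a tensor returns a view that shares storage, whereas torch.stack copies its inputs into fresh storage.

```python
import torch

a = torch.randn(3, 4)

# A slice is a view: it shares the underlying storage with `a`,
# so writing through the view mutates `a`.
row = a[1]
row[0] = 99.0
print(a[1, 0].item())  # 99.0

# torch.stack, by contrast, copies its inputs into new storage,
# so later writes to the stacked tensor do not touch `a`.
stacked = torch.stack([a, a])
print(stacked.data_ptr() == a.data_ptr())  # False
```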
Is it possible to do this, and if so, how?