Computing the mean of a list of generator objects

# Initialize an m-by-n array; each entry is some neural network
phi = [[None] * n for _ in range(m)]
for i in range(m):
    for j in range(n):
        phi[i][j] = NeuralNetwork()

# Let k, i be arbitrary indices
p1 = torch.nn.utils.parameters_to_vector(phi[k][i - 1].parameters())
p2 = torch.nn.utils.parameters_to_vector(mean of phi[:][i-1])  # pseudocode

I basically want to compute the mean squared error between the parameters of phi[k][i-1] and the average of the entire column phi[:][i-1], i.e. ((p1 - p2)**2).sum(). I tried the following:

tmp = [x.parameters() for x in self.phi[:][i - 1]]
mean_params = torch.mean(torch.stack(tmp), dim=0)
p2 = torch.nn.utils.parameters_to_vector(mean_params)

But this doesn’t work, because tmp is a list of generator objects. More specifically, I think my problem is computing the mean from those generator objects.

Three things I can think of:

  • There isn’t anything wrong per se with making a list of the parameters: it loops over the parameters and creates a new list, but you only copy references to the tensors, not the tensors themselves. That can make it easier to express what you want with plain for loops, even if it isn’t the most computationally efficient solution. For example, you could call parameters_to_vector inside the list comprehension.
  • You can nest list comprehensions. For example, you could use
    tmp = [[something(p) for p in x[i - 1].parameters()] for x in self.phi]
    or similar. I would recommend exercising judgement as to whether this stays understandable, and adding a comment explaining what’s going on if you do.
    Using named helper functions — in the spirit of parameters_to_vector — can make this kind of code easier to follow.
    You might need nested cat/stack calls if you want to build a “matrix of means”.
  • If you don’t need gradients, you likely want to detach the parameters before computing with them.
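Putting the bullets above together, here is a minimal sketch of one way to do it. I’m assuming `phi` is the m-by-n list of networks from your snippet (using `nn.Linear` as a stand-in for your `NeuralNetwork`), and that `phi[:][i-1]` was meant to select the column — note that in Python `phi[:]` is just a shallow copy of `phi`, so `phi[:][i-1]` actually gives you *row* i-1; you need to index each row explicitly to get a column:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

m, n = 3, 4
k, i = 1, 2

# Stand-in for the networks in the question; any nn.Module works here.
phi = [[nn.Linear(5, 5) for _ in range(n)] for _ in range(m)]

# phi[:][i - 1] would be row i - 1; selecting column i - 1 needs a loop over rows.
column = [phi[r][i - 1] for r in range(m)]

# Flatten each network's parameters into one vector (this materializes the
# generators), then stack and average. detach() since no gradients are needed.
vecs = torch.stack([parameters_to_vector(net.parameters()).detach()
                    for net in column])
p2 = vecs.mean(dim=0)

p1 = parameters_to_vector(phi[k][i - 1].parameters()).detach()
mse = ((p1 - p2) ** 2).sum()
```

If you do want gradients to flow through the average, drop the `detach()` calls.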

As an aside, some care should probably be taken with the order of parameters. While I think they will usually come up in the same order for different instances of the same network, I’m not entirely sure that this is a hard guarantee (it might fail if people do really funny things during instantiation, or with dict ordering in older Python; there’s probably little risk with standard PyTorch modules these days). It certainly isn’t an explicit guarantee as far as I know. That said, it may not matter for your use case at all: if you’re computing statistics that don’t depend on the order of the parameters, you should be fine anyway. It could be that I’m overthinking this.
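If you want to be defensive about this, you can compare parameter names and shapes across instances before averaging. A small sanity check (again using `nn.Linear` as a hypothetical stand-in for your network class):

```python
import torch.nn as nn

net_a, net_b = nn.Linear(5, 5), nn.Linear(5, 5)

# named_parameters() yields (name, tensor) pairs; two instances of the same
# module class should produce identical name/shape signatures in the same order.
sig_a = [(name, tuple(p.shape)) for name, p in net_a.named_parameters()]
sig_b = [(name, tuple(p.shape)) for name, p in net_b.named_parameters()]

assert sig_a == sig_b
```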

Best regards