Hi,
I’d like to replace a for-loop over multiple NNs with something like a single batched matrix operation, so the GPU is used more efficiently.
(Note: the following code is conceptual; it is not runnable as-is.)
For example, I have a bunch of NNs contained in a torch.nn.ModuleList.
That is,
import torch

my_list = torch.nn.ModuleList()
for _ in range(N):
    my_list.append(construct_my_NN())  # construct_my_NN() builds one small network
Now, in the forward method, I use the outputs of all the NNs constructed above.
def forward(self, xs):
    d = xs.shape[0]
    ys = torch.empty(N, d, 1)  # one slot per NN, assuming each NN maps xs to shape (d, 1)
    for i in range(N):
        ys[i] = my_list[i](xs)
    return my_custom_func(ys)
Note that my_custom_func is torch.sum in my actual code, but it could be a different custom function.
I think this for-loop makes the forward method very slow.
Can I replace the for-loop with a single operation (something like one matrix operation)?
Or does anyone have ideas to work around this issue?
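For reference, this is the kind of batched evaluation I have in mind, as a rough untested sketch using torch.func (assuming PyTorch >= 2.0, that all N networks share the same architecture, and with the sizes and construct_my_NN below as placeholders):

import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

N, d, in_features = 10, 7, 3            # placeholder sizes for illustration

def construct_my_NN():                   # placeholder for the real constructor
    return torch.nn.Linear(in_features, 1)

models = [construct_my_NN() for _ in range(N)]
params, buffers = stack_module_state(models)      # stack weights along a new leading dim

base_model = copy.deepcopy(models[0]).to('meta')  # stateless template of the architecture

def call_single(p, b, x):
    return functional_call(base_model, (p, b), (x,))

xs = torch.randn(d, in_features)
# vmap over the stacked parameter dim; xs is shared across all N networks
ys = vmap(call_single, in_dims=(0, 0, None))(params, buffers, xs)  # shape (N, d, 1)

Is something like this the right direction?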
EDIT: what if one replaces the multiple NNs with a single NN that has a high-dimensional output?
For example, the above forward method would become
def forward(self, xs):
    ys = huge_NN(xs)  # a single network whose output covers all N "heads" at once
    return my_custom_func(ys)
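where huge_NN could be, for instance, a single network with N output units. A hypothetical example (the layer sizes are placeholders, and it assumes each original NN returns one value per sample):

import torch
from torch import nn

N, in_features, d = 10, 3, 7   # placeholder sizes for illustration

huge_NN = nn.Sequential(
    nn.Linear(in_features, 64),
    nn.ReLU(),
    nn.Linear(64, N),  # output column i plays the role of my_list[i](xs)
)

xs = torch.randn(d, in_features)
ys = huge_NN(xs)  # shape (d, N): all N outputs computed in one matrix multiply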
I suspect the two cases differ in numerical sensitivity or training behavior.
Note that I actually implemented both cases and found that, as expected, they showed different learning progress.
Can anyone give me an intuition for why?