I would like to know how to parallelize this algorithm.
In my current implementation:
- I am looping over `x` and `y` like in the pseudo code, and it takes too much time when `n` and `m` are high.
- I am using `torch.nn.utils.parameters_to_vector` to convert my parameters to a vector, then perturbing it and converting it back to parameters with `torch.nn.utils.vector_to_parameters`.
- I am using PyTorch Lightning for evaluation: `loss = trainer.validate(model, dataloader)`

It is even slower in this second case.
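
For reference, here is a rough sketch of one way the double loop might be batched. It assumes the pseudo code evaluates the loss at `theta0 + x * d1 + y * d2` for every grid point, which is a guess on my part. The names `grid_losses`, `d1`, `d2`, `xs`, `ys` and `loss_fn` are hypothetical and not from the original code. It uses `torch.func.functional_call` to run the model with perturbed parameters without writing them back into the module, and `torch.func.vmap` to evaluate many grid points in one call (both need PyTorch 2.0 or later). It handles a single batch only, so a full dataloader would need an outer loop that accumulates the per-batch losses.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap


def grid_losses(model, loss_fn, inputs, targets, d1, d2, xs, ys, chunk_size=None):
    """Loss at theta0 + x*d1 + y*d2 for every (x, y) pair on the grid.

    d1 and d2 are dicts keyed like model.named_parameters() (an assumption
    about how the perturbation directions are stored).
    """
    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def loss_at(x, y):
        # Build the perturbed parameters functionally; the module itself
        # is never modified, so no vector <-> parameter round trip is needed.
        perturbed = {k: params[k] + x * d1[k] + y * d2[k] for k in params}
        out = functional_call(model, (perturbed, buffers), (inputs,))
        return loss_fn(out, targets)

    gx, gy = torch.meshgrid(xs, ys, indexing="ij")
    with torch.no_grad():
        # chunk_size bounds how many grid points are processed at once,
        # which limits memory when n * m is large.
        flat = vmap(loss_at, chunk_size=chunk_size)(gx.reshape(-1), gy.reshape(-1))
    return flat.reshape(len(xs), len(ys))


def loss_fn(out, target):
    # Plain mean squared error, written out by hand for the example.
    return ((out - target) ** 2).mean()


# Toy usage with a small model in eval mode (dropout would need extra care under vmap).
torch.manual_seed(0)
model = nn.Linear(3, 2).eval()
inputs, targets = torch.randn(8, 3), torch.randn(8, 2)
d1 = {k: torch.randn_like(v) for k, v in model.named_parameters()}
d2 = {k: torch.randn_like(v) for k, v in model.named_parameters()}
xs = torch.linspace(-1.0, 1.0, 5)  # n = 5
ys = torch.linspace(-1.0, 1.0, 4)  # m = 4
losses = grid_losses(model, loss_fn, inputs, targets, d1, d2, xs, ys, chunk_size=8)
```

On the Lightning side, each `trainer.validate` call probably carries its own setup cost (loops, callbacks, logging), so calling it once per grid point is likely to add up. A plain forward pass under `torch.no_grad()`, as above, may be lighter for this kind of sweep.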