Training a linear layer with a 2D input

Alright, that makes sense!

Can you point me to the PyTorch docs that imply the gradients of the first layers (before aggregation) will be averaged?
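To make the question concrete, here is a minimal sketch of what I understand by "averaged". It assumes the aggregation is a plain mean over the rows of the 2D input; the shapes and the names `layer`, `x`, `g_full`, and `g_avg` are made up for illustration and are not from my actual model.

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: one linear layer applied to a 2D input (4 rows, 3 features).
layer = torch.nn.Linear(3, 2)
x = torch.randn(4, 3)

# Forward pass with an assumed mean aggregation over the rows, then a scalar loss.
loss = layer(x).mean(dim=0).sum()
loss.backward()
g_full = layer.weight.grad.clone()

# Compute the gradient separately for each row, without any aggregation.
per_row = []
for i in range(x.shape[0]):
    layer.zero_grad()
    layer(x[i]).sum().backward()
    per_row.append(layer.weight.grad.clone())

# Average the per-row gradients and compare with the aggregated case.
g_avg = torch.stack(per_row).mean(dim=0)
print(torch.allclose(g_full, g_avg, atol=1e-6))
```

If the aggregation were a sum instead of a mean, I would expect the weight gradient to be the sum of the per-row gradients rather than their average.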

Do you have a better idea for doing this kind of aggregation?