Hello,
I have a question about how the loss is computed when using either DP or DDP. I have a particular loss function (physics related) that is highly sensitive to batch size, so I am wondering whether there is a way to use a distributed method that aggregates all network outputs onto a single device BEFORE computing the loss. The loss is analytically better behaved with more data, i.e. it involves estimating a linear operator constructed from covariance matrices, and it is not hard to show that computing this loss separately on each scattered sub-batch and then aggregating the results is undesirable for my application.

While I understand that under normal circumstances what I am shooting for doesn't make much sense and is not desirable computationally, I am certain that for this particular application, aggregating the network outputs before estimating the loss would be highly beneficial. Is there a way to do this? Answers explaining that I shouldn't want to do this in the first place are not appreciated; my motivation is entirely application based and outside the scope of this forum (physics / applied math).
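To make the question concrete, here is roughly the kind of thing I have in mind (a sketch, not my real code). If I understand correctly, torch.distributed.nn.all_gather is autograd-aware, so something like the following might let every rank see the full set of outputs before the loss is evaluated; covariance_loss is just a stand-in for my actual physics objective, and I assume the process group is already initialized (e.g. via torchrun) and the model is wrapped in DDP:

```python
import torch
import torch.distributed.nn  # autograd-aware collectives (all_gather with gradients)

def covariance_loss(outputs: torch.Tensor) -> torch.Tensor:
    # Stand-in for my real objective: some scalar built from the
    # covariance of ALL outputs, not per-sub-batch covariances.
    centered = outputs - outputs.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / (outputs.shape[0] - 1)
    return cov.trace()

def training_step(model, local_batch):
    local_out = model(local_batch)                 # (B_local, D) on this rank
    # Gather outputs from every rank; gradients should flow back through
    # the collective to each rank's copy of the model.
    gathered = torch.distributed.nn.all_gather(local_out)
    full_out = torch.cat(gathered, dim=0)          # (world_size * B_local, D)
    loss = covariance_loss(full_out)               # loss sees the full batch
    loss.backward()
    return loss
```

Is this the supported way to do it, or is there a better-suited mechanism (e.g. gathering onto a single rank) for this kind of loss?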