I am running into memory issues with the scatter_add function.

The standard implementation computes, e.g. with dim=1, self[i][index[i][j][k][m]][k][m] += src[i][j][k][m], where the output has shape I x Index_Max x K x M. This tensor is often large, which causes the memory issues. Afterwards I sum over certain dimensions; e.g. if I choose dimensions 0 and 3, the result is a tensor of shape 1 x Index_Max x K x 1.
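For reference, the two-step approach described above looks roughly like this in PyTorch. This is a sketch assuming 4-D tensors and a scatter along dim=1; the function name `scatter_add_then_sum` is hypothetical. The peak memory comes from `full`, which has I x Index_Max x K x M elements even though only Index_Max x K of them survive the sum.

```python
import torch

def scatter_add_then_sum(index, src, index_max):
    """Naive approach: materialize the full I x index_max x K x M output, then sum dims 0 and 3."""
    I, J, K, M = src.shape
    full = torch.zeros(I, index_max, K, M, dtype=src.dtype, device=src.device)
    # full[i][index[i][j][k][m]][k][m] += src[i][j][k][m]
    full.scatter_add_(1, index, src)
    # Reduce the unwanted dimensions afterwards: result has shape 1 x index_max x K x 1.
    return full.sum(dim=(0, 3), keepdim=True)
```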

I am looking for an option, or a workaround, to choose some of the output dimensions to have size 1. In the example above, the scatter_add function would compute

self[0][index[i][j][k][m]][k][0] += src[i][j][k][m]

(or ignore dimensions 0 and 3 altogether), without ever computing the entire output tensor.
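One possible workaround, sketched under the assumption that the scatter is along dim=1 and that the kept dimensions are 1 and 2 as in the example: since the target position only depends on (index[i][j][k][m], k), the summed dimensions i, j and m can be flattened into a single axis and scattered into a small K x Index_Max buffer directly. The helper name `scatter_add_reduced` is hypothetical; it only uses the standard `permute`, `reshape` and `scatter_add_` tensor methods.

```python
import torch

def scatter_add_reduced(index, src, index_max):
    """Compute out[0][n][k][0] = sum of src[i][j][k][m] over all (i, j, m) with
    index[i][j][k][m] == n, without allocating the I x index_max x K x M tensor."""
    I, J, K, M = src.shape
    # Move the kept dimension k to the front and flatten the summed dims (i, j, m).
    # Note: reshape after permute copies index/src, so this costs input-sized memory,
    # not output-sized memory.
    idx_flat = index.permute(2, 0, 1, 3).reshape(K, I * J * M)
    src_flat = src.permute(2, 0, 1, 3).reshape(K, I * J * M)
    # Small K x index_max accumulator: out[k][index[i][j][k][m]] += src[i][j][k][m]
    out = torch.zeros(K, index_max, dtype=src.dtype, device=src.device)
    out.scatter_add_(1, idx_flat, src_flat)
    # Restore the requested layout 1 x index_max x K x 1.
    return out.t().reshape(1, index_max, K, 1)
```

The result should match `scatter_add_` on the full tensor followed by `.sum(dim=(0, 3), keepdim=True)`, while the largest allocation is the size of the inputs rather than I x Index_Max x K x M. Other reduced dimension choices would need a different permutation.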

Thank you for the suggestions!