Scatter_add reduce output dimensions/shape

I have memory issues with the scatter_add function.

The standard implementation computes, e.g. for dim=1, self[i][index[i][j][k][m]][k][m] += src[i][j][k][m], where the output will have shape I x Index_Max x K x M. This is often a large tensor, which causes the memory issues. I then sum over certain dimensions, e.g. if I choose dimensions 0 and 3, the result is a tensor of shape 1 x Index_Max x K x 1.
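For reference, here is a minimal sketch of that "scatter first, reduce afterwards" pattern. The sizes are hypothetical and chosen only for illustration; the point is that the full I x Index_Max x K x M tensor is materialized before the sum shrinks it.

```python
import torch

# Hypothetical sizes, for illustration only.
I, J, K, M = 4, 5, 3, 6
N = 10  # plays the role of Index_Max

src = torch.randn(I, J, K, M)
index = torch.randint(0, N, (I, J, K, M))  # scatter_add needs int64 indices

# Full scatter_add along dim 1:
# out[i][index[i][j][k][m]][k][m] += src[i][j][k][m]
out = torch.zeros(I, N, K, M).scatter_add(1, index, src)
print(out.shape)  # torch.Size([4, 10, 3, 6])

# Reduce dims 0 and 3 afterwards; the large tensor above already exists.
reduced = out.sum(dim=(0, 3), keepdim=True)
print(reduced.shape)  # torch.Size([1, 10, 3, 1])
```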

I am looking for an option, or workaround, to make some of the output dimensions size 1. In the example above, the scatter_add function would compute
self[0][index[i][j][k][m]][k][0] += src[i][j][k][m]
(or ignore dimensions 0 and 3 altogether), without computing the entire output tensor.
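To pin down the desired semantics, a slow loop-based reference might look like the sketch below. The function name reduced_scatter_add_reference is hypothetical (not a PyTorch API), and the Python loops make it suitable only for tiny tensors or for checking a faster implementation; it only ever allocates an Index_Max x K accumulator.

```python
import torch

def reduced_scatter_add_reference(index, src, n):
    """Hypothetical, loop-based reference for
    out[0][index[i][j][k][m]][k][0] += src[i][j][k][m].
    Only an n x K accumulator is allocated; the result is viewed as 1 x n x K x 1."""
    I, J, K, M = src.shape
    out = torch.zeros(n, K, dtype=src.dtype)
    for i in range(I):
        for j in range(J):
            for k in range(K):
                for m in range(M):
                    # Dims 0 and 3 are collapsed: every (i, m) lands in the same row/column.
                    out[index[i, j, k, m].item(), k] += src[i, j, k, m]
    return out.view(1, n, K, 1)

src = torch.randn(2, 3, 2, 4)
index = torch.randint(0, 5, (2, 3, 2, 4))
print(reduced_scatter_add_reference(index, src, 5).shape)  # torch.Size([1, 5, 2, 1])
```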

Thank you for the suggestions!

In that case, what about:

  1. Perform the sum reduction on self first (self then has shape Index_Max x K): self = self.sum(0).sum(2)
  2. Reshape index and src to remove the obsolete dimensions, keeping K as the last dimension (both then have shape (M*I*J) x K): index = torch.reshape(index.permute(3, 0, 1, 2), (-1, index.size(2)))
    and src = torch.reshape(src.permute(3, 0, 1, 2), (-1, src.size(2)))
  3. Call scatter_add on the already reduced self: self = self.scatter_add(0, index, src)
  4. Unsqueeze the reduced dimensions: self = self.unsqueeze(0).unsqueeze(3)
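The four steps can be sketched as one function, assuming self is I x Index_Max x K x M and index/src are I x J x K x M with scatter along dim 1. The name reduced_scatter_add is hypothetical, and self is renamed target to avoid shadowing the Python convention. Because scatter_add is linear, summing self before scattering gives the same result as summing afterwards, which the check at the end compares against the full computation.

```python
import torch

def reduced_scatter_add(target, index, src):
    """Hypothetical helper: scatter_add along dim 1, with dims 0 and 3 summed
    away *before* scattering, so the I x N x K x M result is never built."""
    K = index.size(2)
    # 1. Sum-reduce target over dims 0 and 3 -> N x K
    target = target.sum(0).sum(2)
    # 2. Put K last and flatten the rest -> (M*I*J) x K. reshape copies here
    #    (the permuted view is non-contiguous), but only at the size of index/src.
    index = torch.reshape(index.permute(3, 0, 1, 2), (-1, K))
    src = torch.reshape(src.permute(3, 0, 1, 2), (-1, K))
    # 3. Scatter into the reduced target: target[index[r][k]][k] += src[r][k]
    target = target.scatter_add(0, index, src)
    # 4. Restore the reduced dims as size-1 dims -> 1 x N x K x 1
    return target.unsqueeze(0).unsqueeze(3)

torch.manual_seed(0)
I, J, K, M, N = 4, 5, 3, 6, 10
src = torch.randn(I, J, K, M, dtype=torch.float64)
index = torch.randint(0, N, (I, J, K, M))
target = torch.randn(I, N, K, M, dtype=torch.float64)  # nonzero, to exercise step 1

fast = reduced_scatter_add(target, index, src)
slow = target.scatter_add(1, index, src).sum(dim=(0, 3), keepdim=True)
print(fast.shape)                  # torch.Size([1, 10, 3, 1])
print(torch.allclose(fast, slow))  # True
```

The full-size target here exists only so the result can be compared with the original approach. If self starts as all zeros, step 1 can instead be replaced by allocating torch.zeros(Index_Max, K) directly, so the large tensor is never created.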

Would that work and reduce the memory footprint for you?


+10, thank you very much for the reply.