Parallel sum of specific array values using torch.where()

Hello, I need to manually accumulate the parameter.grad attributes of my network based on the values of another tensor. I’m working with the following variables:

  1. I have a bit-string array of sampled bits m, of size (M,) (where M is the number of samples)
  2. I also have two tensors of gradients, grad1 and grad2, where grad1[i] and grad2[i] hold the gradients corresponding to sampled bits of -1 and 1 respectively. Each individual gradient tensor has a different size depending on the parameter it belongs to. For example, with M=5 (5 sampled bits), grad1 and grad2 will each hold 5 gradient matrices per parameter (see the sketch just after this list).
  3. I have total_grad, which has the same shape as grad1 and grad2; total_grad is incremented by entries from either grad1 or grad2.
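
To make the shapes concrete, here is a minimal setup matching the description above (the specific sizes, values, and variable names are just assumptions for illustration):

import torch

M = 5                                          # number of sampled bits
m = torch.tensor([-1., 1., 1., -1., 1.])       # sampled bits, shape (M,)

# One entry per parameter; each entry stacks the M per-sample gradients,
# e.g. shapes (M, 4, 2) and (M, 4).
grad1 = [torch.randn(M, 4, 2), torch.randn(M, 4)]   # gradients for bits of -1
grad2 = [torch.randn(M, 4, 2), torch.randn(M, 4)]   # gradients for bits of 1
total_grads = [torch.zeros_like(g) for g in grad1]  # accumulator, same structure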

I’m trying to parallelize the accumulation of total_grad based on m. For every sample i, if m[i] > 0 I want to add grad1[i] to total_grad[i]; otherwise I want to add grad2[i]. I don’t want to do this with a Python for loop over the samples (for obvious efficiency reasons), so I tried to parallelize it using torch.where():

for param in range(params):
    total_grads[param][:] += torch.where(m[:] > 0, grad1[param][:], grad2[param][:])

This code appears to work for the first set of parameter gradients (of size (M, 4, 2)), but not for the second (of size (M, 4)). I’m getting the following RuntimeError:

The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1

Can someone help me understand what's going on here, or suggest a better way to do this? Thanks!

Just figured it out! I had to reshape m so that it is broadcast-compatible with the shape of grad1:

reshaped_m = m.reshape(m.shape + (1,)*(grads_per_param[param].ndim-1))
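
To spell out why this works: broadcasting in PyTorch aligns shapes from the trailing dimensions, so torch.where() was matching m of shape (M,) against the last dimension of each gradient tensor rather than the sample dimension. Appending singleton dimensions ((M,) becomes (M, 1, 1) or (M, 1)) makes the condition broadcast over the samples instead. Continuing the sketch from the variable list above (grads_per_param in my snippet corresponds to grad1[param] in that setup), the accumulation loop becomes:

for param in range(len(total_grads)):
    # (M,) -> (M, 1, ..., 1): one trailing singleton per extra gradient dim,
    # so the condition broadcasts along the sample dimension.
    reshaped_m = m.reshape(m.shape + (1,) * (grad1[param].ndim - 1))
    total_grads[param] += torch.where(reshaped_m > 0, grad1[param], grad2[param])

With M = 5 this gives conditions of shape (5, 1, 1) and (5, 1) for the (5, 4, 2) and (5, 4) gradient tensors respectively, which broadcast cleanly inside torch.where().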