Using PyTorch to reduce the for loops in a custom index_add function

I want to reduce or remove the for loops in customIndexAdd(), which reimplements torch.index_add_() (it only handles dim=-2). Could anyone help me write a faster customIndexAdd()? Currently it takes 35 seconds to execute.
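For reference, this is the behaviour customIndexAdd() has to reproduce: along dim=-2, source row k is added into destination row index[k], and repeated indices accumulate rather than overwrite. A tiny example with made-up values:

```python
import torch

# Destination: 3 rows along dim -2, 2 columns along dim -1
out = torch.zeros(1, 3, 2, dtype=torch.int64)
# Source: 4 rows along dim -2
src = torch.tensor([[[1, 1], [2, 2], [3, 3], [4, 4]]])
# Source rows 1 and 2 both target destination row 1
index = torch.tensor([0, 1, 1, 2])

# out[:, index[k], :] += src[:, k, :] for every k
out.index_add_(-2, index, src)
print(out.tolist())  # [[[1, 1], [5, 5], [4, 4]]]
```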

import torch
import numpy as np
import time

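def customIndexAdd(x1, index, tensor):
    # Reference implementation of x1.index_add_(-2, index, tensor) for 4-D tensors
    s1, s2, s3, _ = tensor.shape
    output_tensor = x1  # no copy: x1 is modified in place
    for i in range(s1):
        for j in range(s2):
            for k in range(s3):
                # Add source row k into destination row index[k]
                output_tensor[i, j, index[k]] += tensor[i, j, k]
    return output_tensor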

# Create an array of sequential numbers starting from 1
sequential_numbers = np.arange(1, 2* 2* 352798* 2 + 1)

# Reshape the array to match the desired tensor shape
tensor = sequential_numbers.reshape(2, 2, 352798, 2)
t = torch.tensor(tensor).int()

values = torch.arange(1, 352796 // 2 + 1)

repeated_values = torch.repeat_interleave(values, repeats=2)
final_values = torch.cat([torch.tensor([0]), repeated_values, torch.tensor([176399])])
index = final_values

x = torch.ones(2, 2, 176400, 2).int()
x.index_add_(-2, index, t)

x1 = torch.ones(2, 2, 176400, 2).int()  # int32 like x, so torch.equal compares like with like

start = time.time()
out1 = customIndexAdd(x1, index, t)
end = time.time()
print(end - start)

print(torch.equal(x, out1))
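Because the index built above has a fixed structure (destination row 0 receives source row 0, the last destination row receives the last source row, and every other destination row m receives the two consecutive source rows 2m-1 and 2m), the scatter can be replaced by zero-padding and a reshape. Below is a minimal sketch under that assumption; make_overlap_index and overlap_index_add are hypothetical names. It only uses padding, reshape and a sum, so I would expect the ONNX graph to contain Pad/Reshape/ReduceSum rather than ScatterElements, but I have not checked the exported graph.

```python
import torch
import torch.nn.functional as F


def make_overlap_index(num_out: int) -> torch.Tensor:
    # Hypothetical helper reproducing the index from the question:
    # [0, 1, 1, 2, 2, ..., num_out-2, num_out-2, num_out-1]
    inner = torch.arange(1, num_out - 1).repeat_interleave(2)
    return torch.cat([torch.tensor([0]), inner, torch.tensor([num_out - 1])])


def overlap_index_add(x1: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
    # Loop-free x1.index_add(-2, index, tensor), valid ONLY when index equals
    # make_overlap_index(x1.shape[-2]); returns a new tensor (x1 is untouched).
    n_out = x1.shape[-2]
    assert tensor.shape[-2] == 2 * n_out - 2, "unexpected source length"
    # Add one zero row before and one after along dim -2 -> 2 * n_out rows
    padded = F.pad(tensor, (0, 0, 1, 1))
    # Rows (2m, 2m+1) of padded are exactly the sources of destination row m
    pairs = padded.reshape(*tensor.shape[:-2], n_out, 2, tensor.shape[-1])
    # Sum each pair; integer sums come back as int64, so cast to x1's dtype
    return x1 + pairs.sum(dim=-2).to(x1.dtype)


# Full-size check against the built-in index_add_
n_out = 176400
n_src = 2 * n_out - 2  # 352798
t = torch.arange(1, 2 * 2 * n_src * 2 + 1).reshape(2, 2, n_src, 2).int()
index = make_overlap_index(n_out)
x = torch.ones(2, 2, n_out, 2).int()
expected = x.clone().index_add_(-2, index, t)
out = overlap_index_add(x, t)
print(torch.equal(expected, out))  # True
```

This trick depends entirely on the index pattern; if the index changes (for example a different overlap between frames), the pad-and-reshape no longer applies.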

I’m not sure I understand the use case completely, but why don’t you just use index_add_, which seems to work already?

While converting the model from PyTorch to ONNX, I run into a problem with the ScatterElements node that index_add_() produces. So I need to eliminate index_add_() and implement a custom index_add_(). The detailed issues are here: [Conv-TasNet] Facing issue in converting Conv-TasNet model · Issue #447 · PINTO0309/onnx2tf · GitHub, and Facing error while using onnx from scatterelements · Issue #106973 · pytorch/pytorch · GitHub

Also, for my use case dim=-2 is enough, so I am only implementing that case.
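One scatter-free way to express index_add_ along dim=-2 for an arbitrary index is a matrix product: build a 0/1 matrix M with M[m, k] = 1 where index[k] == m, then add M @ tensor. The sketch below (onehot_index_add is a hypothetical name) uses only a comparison, a cast and a matmul, which should map to ONNX Equal/Cast/MatMul, though I have not verified the export. The catch is memory: M is dense with x1.shape[-2] * len(index) entries, which for the sizes in the question is about 62 billion entries (roughly 250 GB in float32), so this is only practical for small dimensions. The example uses float inputs.

```python
import torch


def onehot_index_add(x1: torch.Tensor, index: torch.Tensor,
                     tensor: torch.Tensor) -> torch.Tensor:
    # Sketch: x1.index_add(-2, index, tensor) for ANY index, as a matmul.
    # Memory grows as n_out * n_src, so only suitable for small sizes.
    n_out = x1.shape[-2]
    # m[row, k] = 1 where index[k] == row (a comparison, no scatter)
    m = (torch.arange(n_out).unsqueeze(1) == index.unsqueeze(0)).to(tensor.dtype)
    # (n_out, n_src) @ (..., n_src, C) broadcasts to (..., n_out, C)
    return x1 + torch.matmul(m, tensor)


# Small check against the built-in index_add_; integer-valued floats keep
# the sums exact so torch.equal can be used
torch.manual_seed(0)
x = torch.ones(2, 2, 5, 3)
index = torch.tensor([4, 0, 0, 2, 4, 1, 0])
t = torch.randint(-5, 5, (2, 2, 7, 3)).float()
expected = x.clone().index_add_(-2, index, t)
out = onehot_index_add(x, index, t)
print(torch.equal(expected, out))  # True
```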

These are the lines I need to replace so that no ScatterElements node is generated:

This version takes 11 seconds, but I need it to be even faster:

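def customIndexAdd(x1, index, tensor):
    # One Python iteration per source row along dim -2 (352798 here);
    # slicing handles the two leading dims in a single vectorized add
    for j in range(len(index)):
        x1[:, :, index[j], :] += tensor[:, :, j, :]
    return x1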
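If the index is sorted in non-decreasing order (as it is in the question), the remaining Python loop can be removed with prefix sums: take a cumulative sum along dim=-2 and, for each destination row, subtract the prefix sum at the start of its run of equal indices from the prefix sum at the end. The sketch below (sorted_index_add is a hypothetical name) assumes a sorted index and accumulates in int64 or float64 so the large prefix sums neither overflow nor lose much precision. torch.searchsorted may not export to ONNX, but since the index is fixed, starts and ends could be computed once outside the exported model and baked in as constants; I have not tested that export path.

```python
import torch


def sorted_index_add(x1: torch.Tensor, index: torch.Tensor,
                     tensor: torch.Tensor) -> torch.Tensor:
    # Sketch: x1.index_add(-2, index, tensor), ASSUMING index is sorted
    # (non-decreasing). Uses prefix sums and gathers instead of a scatter.
    n_out = x1.shape[-2]
    targets = torch.arange(n_out, dtype=index.dtype)
    # Source rows starts[m]:ends[m] are exactly the ones with index == m.
    # index is fixed, so these two could be precomputed once.
    starts = torch.searchsorted(index, targets, right=False)
    ends = torch.searchsorted(index, targets, right=True)
    # Wide accumulator: int64 for integer inputs, float64 otherwise
    acc = torch.float64 if tensor.is_floating_point() else torch.int64
    # Prepend a zero row so csum[..., k, :] is the sum of source rows < k
    zero = torch.zeros_like(tensor[..., :1, :])
    csum = torch.cat([zero, tensor], dim=-2).to(acc).cumsum(dim=-2)
    # Sum over each run = prefix sum at its end minus prefix sum at its start
    seg = csum.index_select(-2, ends) - csum.index_select(-2, starts)
    return x1 + seg.to(x1.dtype)


# Check against index_add_ with a random sorted index; destination rows
# that no source row maps to get starts == ends and therefore add zero
torch.manual_seed(0)
index = torch.randint(0, 6, (20,)).sort().values
x = torch.ones(2, 2, 6, 2).int()
t = torch.randint(-10, 10, (2, 2, 20, 2)).int()
expected = x.clone().index_add_(-2, index, t)
print(torch.equal(sorted_index_add(x, index, t), expected))  # True
```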