Is it possible to write this without a for loop (to speed it up)?

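import torch as to

def forward(x, dur_cnts):
    # x: (N, T, C); dur_cnts: (durs, counts), one entry per group of consecutive rows
    N, T, C = x.shape
    durs, counts = dur_cnts
    idx = 0
    # Was self.dim_in; it has to equal C for the assignment below to work
    joined = to.zeros((N, C), dtype=to.float32, device=x.device)
    for dur, count in zip(durs, counts):
        # Sum the first `dur` time steps of the next `count` rows
        x_part = x[idx: idx + count, :dur]
        joined[idx: idx + count] = x_part.sum(dim=1)
        idx += count
    return joined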

The input has size (N, T, C). I’m taking differently sized slices (much shorter than T) from it and then summing each one across the T dimension.

I think it would be possible with some sort of gather_add operation.


I’m afraid that if you have general indices and counts, you won’t be able to do this in a single op.
You will need a custom kernel that implements this.
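
That said, if every slice really is a prefix `:dur` of the time axis, as in the loop in the question, a mask over T might let you drop the Python loop using a handful of ordinary ops. This is only a sketch: `prefix_sum_masked` is a made-up name, and it assumes `durs` and `counts` are 1-D integer tensors with `counts.sum() == N`. It also sums over the full T rather than just the first `dur` steps, so when the durations are much smaller than T it may not beat the loop. Worth benchmarking on your shapes.

```python
import torch

def prefix_sum_masked(x, durs, counts):
    # x: (N, T, C); durs, counts: 1-D integer tensors (assumed), counts.sum() == N
    N, T, _ = x.shape
    durs = durs.to(x.device)
    counts = counts.to(x.device)
    # Expand the per-group duration to one duration per row: shape (N,)
    row_durs = torch.repeat_interleave(durs, counts)
    # mask[n, t] is True for the time steps the loop would have summed
    mask = torch.arange(T, device=x.device).unsqueeze(0) < row_durs.unsqueeze(1)
    # Zero out steps past each row's duration, then sum over T
    return (x * mask.unsqueeze(-1).to(x.dtype)).sum(dim=1)

# Example: two groups, 3 rows with dur=2 and 2 rows with dur=3
x = torch.randn(5, 4, 3)
durs = torch.tensor([2, 3])
counts = torch.tensor([3, 2])
out = prefix_sum_masked(x, durs, counts)  # shape (5, 3)
```

For slices that are not prefixes (general start and end indices), this trick does not apply directly and the point about needing a custom kernel still stands.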
