Backward operation does not scale as expected

Hello everyone,

I am trying to understand why the computational time of the backward pass of a given operation does not scale as expected.

Consider the following code

import time
import torch

J, S, L, M = 1000, 10, 10, 10  # L is varied between 1 and 10 in the timings below
N = M * J
p = torch.nn.Parameter(torch.zeros(L, S, N), requires_grad=True)
p_prod = torch.stack([torch.prod(p[:, :, range(i * M, (i + 1) * M - 1)], 1) for i in range(J)])
out = p_prod.sum()

start_time = time.time()
out.backward(inputs=[p])
print("--- %s seconds ---" % (time.time() - start_time))

which essentially computes, separately for each of the J blocks of columns along the last dimension, the product of p over the second dimension (of size S), keeping the first dimension (of size L) intact.
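
For reference, the intermediate shapes are (with J = 1000, S = 10, M = 10; note that each slice covers M - 1 = 9 columns):

# p                                        : (L, S, N) = (L, 10, 10000)
# p[:, :, range(i * M, (i + 1) * M - 1)]   : (L, S, M - 1) = (L, 10, 9)
# torch.prod(..., 1)                       : (L, M - 1) = (L, 9)
# p_prod (after torch.stack)               : (J, L, M - 1) = (1000, L, 9)
# out                                      : scalar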

For L = 1, the backward pass takes 0.19535231 seconds, but for L = 10 it takes 5.62876 seconds, i.e. roughly 29x slower, which is well above the roughly 10x increase I was expecting.

Is there a reason why the backward pass does not scale linearly with L? Moreover, is there a way to speed up the gradient computation when I perform operations across several dimensions, either by computing the product in a different way or by using parallelisation? One idea I had was to vectorise the loop over J; a rough sketch is below.
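
This is only a sketch I have been considering, not benchmarked, reusing p, L, S, J, M from the snippet above and assuming the blocks of M columns are contiguous (each slice above keeps only the first M - 1 columns of its block, which the reshape reproduces):

# Vectorised variant: reshape the last dimension of size N = J * M into (J, M) blocks,
# avoiding the Python-level loop over J and the torch.stack.
p_blocks = p.view(L, S, J, M)[..., : M - 1]                # (L, S, J, M - 1); same columns as the slices above
p_prod_vec = torch.prod(p_blocks, dim=1).permute(1, 0, 2)  # (J, L, M - 1); should match p_prod
out_vec = p_prod_vec.sum()                                 # same value as out above

Would the backward pass of something like this be expected to scale better with L?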

Thanks in advance