Backpropagation does not have the same time complexity as the forward pass in a simple MLP with views/slicing

In a simple multilayer perceptron (fully connected NN), the forward pass is just a sequence of matrix multiplications composed with some nonlinearities. So the time complexity of the forward pass is basically that of the matmul operations, which is polynomial: O(n^3) for the standard algorithm on n x n matrices, with a best known theoretical exponent of about 2.37. The backward pass should have the same time complexity as the forward pass, since it consists of basically the same matmul operations.

I am working on implementing a technique for approximate matrix multiplication in which I adaptively downsample the weight matrices of each layer of an MLP, so that no matter how big the layer weight matrices are, the size of the down-sampled matrices remains roughly constant, or at most grows sub-linearly with the size of the original matrices.

Here’s a minimal working example:

import torch
import cProfile

def test1(x,W):
    return torch.sum(x @ W)

def test2(x,W,S):
    return torch.sum(x @ W[:,S])

dim1,dim2 = 784,60000
x = torch.randn(1,dim1)
W1 = torch.randn(dim1,dim2,requires_grad=True)
W2 = torch.randn(dim1,dim2,requires_grad=True)
S = torch.randint(0,dim2,(50,))

cProfile.run("test1(x,W1)") # 0.003 seconds -> 0.009 -> 0.016, appears to be scaling superlinearly
cProfile.run("test2(x,W2,S)") # 0.006 seconds -> 0.005 -> 0.006, runs in constant time as expected
cProfile.run("test1(x,W1).backward()") # 0.016 seconds -> 0.078 -> 0.144, scaling superlinearly
cProfile.run("test2(x,W2,S).backward()") # 0.023 seconds -> 0.065 -> 0.220, also runs about as slow as test1

Why does the backward pass of test2 run as slowly as that of test1 and scale superlinearly, when I expect it to run in constant time like the forward pass of test2?
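
One thing worth checking here is the shape of the gradient that autograd produces. The sketch below is a diagnostic, not a definitive explanation of the timings. It assumes the cost comes from the backward of W[:,S] writing into a dense gradient the size of the full W, and it compares that with differentiating only with respect to a detached copy of the sampled columns. The names k, W_sub, and the use of torch.randperm (to get distinct indices, unlike the torch.randint in the example above) are illustrative choices, and dim2 is reduced to keep the check fast.

```python
import torch

dim1, dim2, k = 784, 6000, 50   # smaller dim2 than above, only to keep this quick
x = torch.randn(1, dim1)
W = torch.randn(dim1, dim2, requires_grad=True)
S = torch.randperm(dim2)[:k]    # assumption: distinct column indices

# Same computation as test2: slice the columns, multiply, reduce, backprop.
torch.sum(x @ W[:, S]).backward()

# The gradient stored on W has the full shape of W, even though only
# the k sampled columns received a contribution.
grad_shape = tuple(W.grad.shape)
nonzero_cols = int((W.grad.abs().sum(dim=0) != 0).sum())

# Possible workaround (hypothetical): make the sampled slice its own leaf
# tensor, so the gradient that is allocated has only k columns.
W_sub = W.detach()[:, S].clone().requires_grad_(True)
torch.sum(x @ W_sub).backward()
sub_grad_shape = tuple(W_sub.grad.shape)

print(grad_shape, nonzero_cols, sub_grad_shape)
```

If the dense W.grad is indeed the bottleneck, the second variant keeps the gradient at size dim1 x k, at the price of having to write any update to W_sub back into W[:,S] manually.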