Batching Question

I have a complicated graph-neural-network-like PyTorch model that is very hard for me to vectorize.

I am looking forward to using the new vmap functionality and might even start trying it out from the master branch on GitHub, but in the meantime I'm trying to understand more about vectorization in general.

I made the following short script to test how important batching is for speed in PyTorch.
When I use the batched version, it’s about 77 times faster.
(I used input and output dimensions of 256 and 10,000 vectors.)

import torch
import argparse
from time import time

parser = argparse.ArgumentParser(
    description="A short program to test how much batching helps in PyTorch"
)
parser.add_argument("i", type=int, help="input vector dimension")
parser.add_argument("o", type=int, help="output vector dimension")
parser.add_argument(
    "n", type=int, help="number of vectors to multiply the matrix by"
)
parser.add_argument("--batched", action="store_true")
args = parser.parse_args()

layer = torch.nn.Linear(args.i, args.o)
xs = torch.randn(args.n, args.i, requires_grad=False)

print("start")
t1 = time()
if args.batched:
    ys = layer(xs)  # one call: a single batched matmul over all n vectors
else:
    for x in xs:
        y = layer(x)  # n separate calls: one small matmul per vector
t2 = time()
print("end")
print(f"time elapsed: {t2-t1}")

Basically, my question comes down to the following:
What are all of the factors that contribute to making the batched code take roughly 1/77th the time, even on the CPU?

For loops add Python overhead in each iteration and call into the dispatcher for only a small workload.
A lot of operations use vectorization internally, which won't be fully utilized on small inputs.
Note that this is not PyTorch-specific; the same effect can be seen in e.g. NumPy:

import numpy as np

def fun1(x):
    # Python loop: one np.sqrt call (and one dispatch) per row
    out = []
    for x_ in x:
        out.append(np.sqrt(x_))
    out = np.stack(out)
    return out

def fun2(x):
    # a single vectorized call over the whole array
    out = np.sqrt(x)
    return out

x = np.random.rand(1000, 100)
%timeit out1 = fun1(x)
%timeit out2 = fun2(x)
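In a typical run, fun1 comes out dramatically slower than fun2: each per-row np.sqrt call pays the interpreter and call-dispatch overhead separately, while the single batched call amortizes that fixed cost over the whole array.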

For loops add Python overhead in each iteration and call into the dispatcher for only a small workload.

The for loop itself adds very little overhead, as far as I could tell from profiling with line_profiler.
Is the PyTorch dispatch overhead really enough to explain the code being 77 times slower?
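To try to put a number on the per-call overhead myself, here is a rough sketch (not a rigorous benchmark; warm-up and timer resolution matter) that times a single-vector call in a tight loop:

import torch
from time import time

layer = torch.nn.Linear(256, 256)
x = torch.randn(256)

for _ in range(100):  # warm-up so one-time costs don't pollute the measurement
    layer(x)

n = 10000
t1 = time()
for _ in range(n):
    # each call does only a small amount of arithmetic, so fixed
    # per-call costs make up a large fraction of the measured time
    layer(x)
t2 = time()
print(f"per-call time: {(t2 - t1) / n * 1e6:.1f} us")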

A lot of operations use vectorization internally, which won't be fully utilized on small inputs.

What does vectorization really mean at a low level?
Do you mean that PyTorch uses AVX2-style SIMD parallelism on the CPU?
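For what it's worth, one way to check which SIMD instruction sets a given PyTorch build was compiled for (the exact output format varies between versions) is to print the build configuration:

import torch

# prints build settings, including a "CPU capability" line
# (e.g. AVX2 or AVX512) for the compiled kernels
print(torch.__config__.show())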