I have a complicated graph-neural-network-like PyTorch model that is very hard for me to vectorize.
I am looking forward to using the new vmap functionality and might even start trying it out from the master branch on GitHub, but in the meantime, I'm trying to understand more about vectorization in general.
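For context, here is roughly how I imagine using it; this is just a minimal untested sketch, assuming the prototype torch.vmap(func, in_dims=0) signature from master, and I don't know yet whether the prototype covers nn.Linear:

import torch

layer = torch.nn.Linear(256, 256)
xs = torch.randn(10000, 256)

# vmap maps a function written for a single example over the leading
# batch dimension, so the per-example code stays loop-shaped.
ys = torch.vmap(lambda x: layer(x))(xs)  # should match stacking layer(x) row by row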
I made the following short script to test how much batching matters for speed in PyTorch.
With input and output dimensions of 256 and 10,000 vectors, the batched version is about 77 times faster.
import torch
import argparse
from time import time

parser = argparse.ArgumentParser(
    description="A short program to test how much batching helps in PyTorch"
)
parser.add_argument("i", type=int, help="input vector dimension")
parser.add_argument("o", type=int, help="output vector dimension")
parser.add_argument(
    "n", type=int, help="number of vectors to multiply matrix by"
)
parser.add_argument("--batched", action="store_true")
args = parser.parse_args()

layer = torch.nn.Linear(args.i, args.o)
xs = torch.randn(args.n, args.i, requires_grad=False)

print("start")
t1 = time()
if args.batched:
    # one call: a single (n, i) x (i, o) matrix-matrix multiply
    ys = layer(xs)
else:
    # n separate calls: one matrix-vector multiply per input
    for x in xs:
        y = layer(x)
t2 = time()
print("end")
print(f"time elapsed: {t2 - t1}")
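For reference, I run it like this (batch_test.py is just a placeholder for whatever the script is saved as):

python batch_test.py 256 256 10000            # per-vector loop
python batch_test.py 256 256 10000 --batched  # single batched call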
Basically, my question comes down to this:
What are all of the factors that contribute to making the batched code take roughly 1/77th the time, even on CPU?
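To be clear, both branches compute the same values; here is a standalone sanity check I'd expect to pass (same dimensions as above, with a loose tolerance since the two paths may accumulate floating-point error differently):

import torch

layer = torch.nn.Linear(256, 256)
xs = torch.randn(10000, 256)

looped = torch.stack([layer(x) for x in xs])  # 10000 matrix-vector products
batched = layer(xs)                           # one matrix-matrix product
print(torch.allclose(looped, batched, atol=1e-6))  # should print True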