Hi @Neel_Nanda,
Can you share a minimal reproducible example?
- When you compare `torch.einsum` with `nn.Linear`, make sure you use `nn.Linear(bias=False)`. Otherwise the two operations aren't equivalent.
- When you measure the runtime of code snippets, make sure you synchronize (`torch.cuda.synchronize()`) before calling `time.time()`. Otherwise you won't record the whole runtime of an operation, only the time it takes to launch it (especially when running on the GPU, where kernels execute asynchronously). There's more info about it here.
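Something like the following could serve as a starting point for such a minimal reproducible example. It's only a sketch, assuming a 2-D input. The sizes and the helper names (`batch`, `d_in`, `d_out`, `with_linear`, `with_einsum`, `sync`, `benchmark`) are hypothetical and were picked just for illustration. It first checks that `torch.einsum` matches `nn.Linear(bias=False)`, then times both and synchronizes before each `time.time()` call:

```python
import time

import torch
import torch.nn as nn

# Hypothetical sizes, chosen small so the sketch also runs quickly on CPU.
batch, d_in, d_out = 512, 256, 256
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(batch, d_in, device=device)
# bias=False so nn.Linear computes only x @ W.T, matching the einsum below.
linear = nn.Linear(d_in, d_out, bias=False).to(device)
W = linear.weight  # shape (d_out, d_in)


def with_linear():
    return linear(x)


def with_einsum():
    # "bi,oi->bo" contracts over the input dimension, i.e. x @ W.T
    return torch.einsum("bi,oi->bo", x, W)


def sync():
    # CUDA kernels are launched asynchronously; block until they finish.
    if device == "cuda":
        torch.cuda.synchronize()


def benchmark(fn, iters=100):
    with torch.no_grad():
        for _ in range(10):  # warm-up runs, not timed
            fn()
        sync()  # make sure warm-up work is done before starting the clock
        start = time.time()
        for _ in range(iters):
            fn()
        sync()  # without this, mostly the launch overhead would be measured
        return (time.time() - start) / iters


with torch.no_grad():
    same = torch.allclose(with_linear(), with_einsum(), atol=1e-4)
print("outputs match:", same)
print(f"nn.Linear: {benchmark(with_linear) * 1e3:.3f} ms per call")
print(f"einsum:    {benchmark(with_einsum) * 1e3:.3f} ms per call")
```

If you drop `bias=False`, the `allclose` check should fail, because `nn.Linear` then adds a (randomly initialized) bias that the einsum doesn't include.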