My model gets worse when I use einsum?

I have two versions of my network. The original version contains this (it’s simplified slightly for ease of reading):

import torch
import torch.nn as nn

nChan = 1
nFilter = 3
timeBins = 4
nFreq = 5

# setup:
L = nn.Parameter(torch.randn(nFreq, nFilter), requires_grad=False)  # fixed, not trained
E = nn.Parameter(torch.randn(nFilter, nChan))                       # learned

# forward():
F = torch.multiply(torch.sigmoid(E[:, 0]), L)   # scale column i of L by sigmoid(E[i, 0])
x1 = torch.matmul(x[:, :, :, 0], F)             # (batch, timeBins, nFreq) @ (nFreq, nFilter)

The ‘improved’ version has the two forward steps merged into one:

x=torch.einsum('btrc,ri,ic->btic',x,L,torch.sigmoid(E))
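
Spelled out in index notation (writing S = sigmoid(E) just for brevity), the einsum computes

x_out[b,t,i,c] = sum_r x[b,t,r,c] * L[r,i] * S[i,c]

and for nChan = 1 (so c = 0) this is exactly

x1[b,t,i] = sum_r x[b,t,r,0] * (S[i,0] * L[r,i]) = matmul(x[:,:,:,0], S[:,0] * L)[b,t,i],

i.e. the same contraction over the frequency index r, just with the sigmoid factor folded into F beforehand.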

Running the two versions on the same input gives the same result (up to float precision):

x0 = torch.randn(10, timeBins, nFreq, 1)

# original version
x = x0
F = torch.multiply(torch.sigmoid(E[:, 0]), L)
x1 = torch.matmul(x[:, :, :, 0], F)    # filtering
x1 = torch.reshape(x1, [-1, timeBins, nFilter])

# einsum version
x = x0
x = torch.einsum('btrc,ri,ic->btic', x, L, torch.sigmoid(E))
x2 = torch.reshape(x, (-1, timeBins, nFilter * nChan))

print(x1 - x2)

tensor([[[ 0.0000e+00, 0.0000e+00, 2.9802e-08],
[ 1.1921e-07, 2.9802e-08, -5.9605e-08],
[ 5.9605e-08, -1.1921e-07, 0.0000e+00],
[ 0.0000e+00, -1.1921e-07, -2.9802e-08]],

    [[-7.4506e-09,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00, -3.7253e-08],
     [-2.9802e-08, -5.9605e-08,  5.9605e-08],
     [ 0.0000e+00, -5.9605e-08, -5.9605e-08]],

    [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
     [-2.9802e-08, -1.4901e-08,  2.9802e-08],
     [ 0.0000e+00, -1.1921e-07,  5.9605e-08],
     [ 0.0000e+00,  1.1921e-07, -1.1921e-07]],

    [[ 0.0000e+00, -5.9605e-08, -2.3842e-07],
     [ 0.0000e+00,  2.9802e-08, -2.3842e-07],
     [ 0.0000e+00, -1.4901e-08,  0.0000e+00],
     [ 0.0000e+00,  2.3842e-07, -1.1921e-07]],

    [[ 2.3842e-07,  3.7253e-08, -4.7684e-07],
     [-1.1176e-08,  0.0000e+00, -1.1921e-07],
     [ 0.0000e+00, -5.9605e-08, -1.7881e-07],
     [-2.9802e-08,  5.9605e-08,  0.0000e+00]],

    [[-5.9605e-08, -7.4506e-09, -1.1921e-07],
     [ 0.0000e+00,  0.0000e+00,  1.1921e-07],
     [ 5.9605e-08,  5.9605e-08, -2.9802e-08],
     [-2.9802e-08,  5.9605e-08, -1.1921e-07]],

    [[-2.9802e-08,  2.9802e-08,  0.0000e+00],
     [ 5.9605e-08,  2.3842e-07,  5.9605e-08],
     [-2.2352e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  2.9802e-08,  5.9605e-08]],

    [[ 0.0000e+00, -1.4901e-08,  0.0000e+00],
     [-5.9605e-08,  0.0000e+00, -8.8476e-09],
     [ 0.0000e+00,  0.0000e+00,  1.4901e-08],
     [ 7.4506e-09, -5.9605e-08,  0.0000e+00]],

    [[ 0.0000e+00,  0.0000e+00, -2.3842e-07],
     [-1.1921e-07, -1.1921e-07,  0.0000e+00],
     [ 2.0023e-08,  0.0000e+00, -1.2547e-07],
     [ 0.0000e+00, -2.2352e-08, -5.9605e-08]],

    [[ 1.4901e-08,  5.9605e-08, -1.1921e-07],
     [ 0.0000e+00,  0.0000e+00,  7.4506e-09],
     [ 0.0000e+00,  2.3842e-07,  5.9605e-08],
     [-5.9605e-08,  2.2352e-08,  0.0000e+00]]], grad_fn=<SubBackward0>)

If I change the datatype to double everywhere, the errors drop to around 1e-16.
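
A minimal sketch of that check, assuming everything is simply cast to double before the comparison (the d-suffixed names exist only in this snippet):

x0d = x0.double()
Ld, Ed = L.double(), E.double()

x1d = torch.matmul(x0d[:, :, :, 0], torch.sigmoid(Ed[:, 0]) * Ld)
x2d = torch.einsum('btrc,ri,ic->btic', x0d, Ld, torch.sigmoid(Ed))[:, :, :, 0]
print((x1d - x2d).abs().max())   # on the order of 1e-16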

However, after switching to the einsum version, my performance has decreased by around 2%. On the specific problem I am working on, this is the difference between being ‘state of the art’ and ‘OK’. I have tried running both versions with a range of different seeds, learning rates, etc., and it is very clear that the einsum version hits a ceiling a few percentage points below that of the original version.

If it helps, I am happy to share the ‘derivation’ of why this is the correct einsum expression, in case the size of the differences above is not convincing enough.

What could be going on?

PyTorch’s einsum, in contrast to NumPy’s, isn’t optimized for speed; it breaks the einsum down into operations that are ultimately (batched) matrix multiplications. The idea I had in mind when implementing it was that it was a mere convenience thing. This turned out to be wrong, but no one has stepped up to provide a genuinely faster reimplementation yet, although some cases have been identified where the batched matrix multiplication could be replaced by, e.g., a plain matmul. (I did experiment with a different reduction kernel, but it seemed to be faster on some inputs and slower on others, not least because it is actually hard to match the matrix multiplication routines shipped in the vendor libraries.)
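
For illustration only (this is not necessarily the exact kernel sequence PyTorch chooses internally), the einsum above could be carried out as a permute, a matmul and an element-wise multiply along these lines:

# out[b, t, i, c] = sum_r x[b, t, r, c] * L[r, i] * S[i, c],  with S = sigmoid(E)
S = torch.sigmoid(E)
tmp = torch.matmul(x.permute(0, 1, 3, 2), L)   # (b, t, c, i): contracts the r axis
out = tmp.permute(0, 1, 3, 2) * S              # broadcast S over the batch/time axes
# 'out' matches torch.einsum('btrc,ri,ic->btic', x, L, S) up to rounding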

So your observation is quite plausible.

Best regards

Thomas

Sorry, I think I have expressed myself poorly.

I do not mean that the calculation is slower. In fact, einsum is faster (around 25% on average, I believe). My problem is that, on average, the accuracy of my predictions (it’s a classification task) is slightly worse.

So I get different outputs, and on average those different outputs are slightly more wrong than what the original version returns.
Is it possible that the gradient is calculated slightly differently, or some such thing?

Oh, sorry, I didn’t get that right.
Well, the gradient would also be subject to numerical accuracy differences similar to those in the forward pass. If you want to be certain, you could extend your test program to run torch.autograd.gradcheck.
Those differences are legitimate, and, unfortunate as it is, it would seem that the model training is quite brittle (which is the case in deep learning more often than it should be).

Best regards

Thomas


Hi again 🙂
Is this what you suggest?

def func1(x, E, L):
    F = torch.multiply(torch.sigmoid(E[:, 0]), L)
    return torch.matmul(x[:, :, :, 0], F)

def func2(x, E, L):
    return torch.einsum('btrc,ri,ic->btic', x, L, torch.sigmoid(E))

# all inputs in double precision, as gradcheck expects
torch.autograd.gradcheck(func1, (x0, E, L))
torch.autograd.gradcheck(func2, (x0, E, L))

I have changed the precision to double everywhere, and it seems that, at least with the default settings, both versions have the correct gradients. I have also tried changing atol and rtol, but cannot seem to find values where one version fails and the other does not.
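
Roughly, those tolerance experiments look like this (the specific atol/rtol values below are just made-up examples):

x0d = x0.double().requires_grad_()
Ed = E.double().detach().requires_grad_()
Ld = L.double().detach().requires_grad_()

for atol, rtol in [(1e-5, 1e-3), (1e-7, 1e-5), (1e-9, 1e-7)]:
    ok1 = torch.autograd.gradcheck(func1, (x0d, Ed, Ld), atol=atol, rtol=rtol, raise_exception=False)
    ok2 = torch.autograd.gradcheck(func2, (x0d, Ed, Ld), atol=atol, rtol=rtol, raise_exception=False)
    print(atol, rtol, ok1, ok2)   # both versions pass or fail together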

This is a bit annoying, both because I adore the idea of einsum (my background is in physics), and because it gives a very nice way of generalizing my code beyond nChan=1.

It seems that, for now, I should just stay away from einsum?

You do whatever keeps PyTorch most useful for you. But beware: different implementations of matmul can and will have numerical precision differences, too.
The mathematician in me is convinced that we’ll eventually have more robust procedures, but clearly we aren’t there yet.
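
As a tiny illustration of that point (not specific to einsum; the sizes here are arbitrary), even just changing the association order of the same float32 product already shifts the low-order bits:

A, B, C = torch.randn(64, 64), torch.randn(64, 64), torch.randn(64, 64)

left  = (A @ B) @ C   # mathematically the same product...
right = A @ (B @ C)   # ...evaluated in a different order
print((left - right).abs().max())   # small but typically non-zero in float32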

Best regards

Thomas