I have two versions of my network. The original version contains this (it’s simplified slightly for ease of reading):
import torch
import torch.nn as nn

nChan = 1
nFilter = 3
timeBins = 4
nFreq = 5

# setup:
L = nn.Parameter(torch.randn(nFreq, nFilter), requires_grad=False)  # fixed filter bank, shape (nFreq, nFilter)
E = nn.Parameter(torch.randn(nFilter, nChan))                       # learnable weights, shape (nFilter, nChan)

# forward():
F = torch.multiply(torch.sigmoid(E[:, 0]), L)  # broadcasts to shape (nFreq, nFilter)
x1 = torch.matmul(x[:, :, :, 0], F)            # (batch, timeBins, nFreq) @ (nFreq, nFilter)
The ‘improved’ version has the two forward steps merged into one:
x = torch.einsum('btrc,ri,ic->btic', x, L, torch.sigmoid(E))
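To spell out how I read the index labels (b=batch, t=timeBins, r=nFreq, i=nFilter, c=nChan), here is an equivalent formulation written as explicit broadcast steps. This is purely illustrative, not code from my network:

# Illustrative re-derivation of the einsum, not part of the model:
# x: (b, t, r, c), L: (r, i), sigmoid(E): (i, c)
W = L.unsqueeze(-1) * torch.sigmoid(E).unsqueeze(0)  # (nFreq, nFilter, nChan): per-channel filter bank
out = torch.einsum('btrc,ric->btic', x, W)           # contract over r (frequency); same result as above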
Running the two versions on the same input gives the same result (up to float precision):
x0 = torch.randn(10, timeBins, nFreq, 1)

# original version:
x = x0
F = torch.multiply(torch.sigmoid(E[:, 0]), L)
x1 = torch.matmul(x[:, :, :, 0], F)  # filtering
x1 = torch.reshape(x1, [-1, timeBins, nFilter])

# einsum version:
x = x0
x = torch.einsum('btrc,ri,ic->btic', x, L, torch.sigmoid(E))
x2 = torch.reshape(x, (-1, timeBins, nFilter * nChan))

print(x1 - x2)
tensor([[[ 0.0000e+00,  0.0000e+00,  2.9802e-08],
         [ 1.1921e-07,  2.9802e-08, -5.9605e-08],
         [ 5.9605e-08, -1.1921e-07,  0.0000e+00],
         [ 0.0000e+00, -1.1921e-07, -2.9802e-08]],

        [[-7.4506e-09,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -3.7253e-08],
         [-2.9802e-08, -5.9605e-08,  5.9605e-08],
         [ 0.0000e+00, -5.9605e-08, -5.9605e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [-2.9802e-08, -1.4901e-08,  2.9802e-08],
         [ 0.0000e+00, -1.1921e-07,  5.9605e-08],
         [ 0.0000e+00,  1.1921e-07, -1.1921e-07]],

        [[ 0.0000e+00, -5.9605e-08, -2.3842e-07],
         [ 0.0000e+00,  2.9802e-08, -2.3842e-07],
         [ 0.0000e+00, -1.4901e-08,  0.0000e+00],
         [ 0.0000e+00,  2.3842e-07, -1.1921e-07]],

        [[ 2.3842e-07,  3.7253e-08, -4.7684e-07],
         [-1.1176e-08,  0.0000e+00, -1.1921e-07],
         [ 0.0000e+00, -5.9605e-08, -1.7881e-07],
         [-2.9802e-08,  5.9605e-08,  0.0000e+00]],

        [[-5.9605e-08, -7.4506e-09, -1.1921e-07],
         [ 0.0000e+00,  0.0000e+00,  1.1921e-07],
         [ 5.9605e-08,  5.9605e-08, -2.9802e-08],
         [-2.9802e-08,  5.9605e-08, -1.1921e-07]],

        [[-2.9802e-08,  2.9802e-08,  0.0000e+00],
         [ 5.9605e-08,  2.3842e-07,  5.9605e-08],
         [-2.2352e-08,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.9802e-08,  5.9605e-08]],

        [[ 0.0000e+00, -1.4901e-08,  0.0000e+00],
         [-5.9605e-08,  0.0000e+00, -8.8476e-09],
         [ 0.0000e+00,  0.0000e+00,  1.4901e-08],
         [ 7.4506e-09, -5.9605e-08,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00, -2.3842e-07],
         [-1.1921e-07, -1.1921e-07,  0.0000e+00],
         [ 2.0023e-08,  0.0000e+00, -1.2547e-07],
         [ 0.0000e+00, -2.2352e-08, -5.9605e-08]],

        [[ 1.4901e-08,  5.9605e-08, -1.1921e-07],
         [ 0.0000e+00,  0.0000e+00,  7.4506e-09],
         [ 0.0000e+00,  2.3842e-07,  5.9605e-08],
         [-5.9605e-08,  2.2352e-08,  0.0000e+00]]], grad_fn=<SubBackward0>)
If I switch every tensor to double precision, the differences drop to around 1e-16.
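For completeness, the double-precision comparison looks roughly like this (a sketch using the names from above, not my exact script):

# Sketch of the float64 check; same setup as above, cast to double:
xd = x0.double()
Ld, Ed = L.double(), E.double()
y1 = torch.matmul(xd[:, :, :, 0], torch.sigmoid(Ed[:, 0]) * Ld)
y2 = torch.einsum('btrc,ri,ic->btic', xd, Ld, torch.sigmoid(Ed)).reshape(-1, timeBins, nFilter)
print((y1 - y2).abs().max())  # ~1e-16 in float64, vs ~1e-07 in float32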
However, after switching to the einsum version, my performance has decreased by around 2%. On the specific problem I am working on, this is the difference between being 'state of the art' and merely 'OK'. I have tried running both versions with a range of seeds, learning rates, etc., and it is very clear that the einsum version plateaus a few percentage points below the original version.
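In case it is relevant, here is a sketch of a check one could run (again using the names from above) to see whether the backward passes also agree; if the gradients with respect to E differed by more than float noise, that would point to a real discrepancy rather than precision:

# Hypothetical gradient check: compare dLoss/dE for the two formulations.
E.grad = None
out1 = torch.matmul(x0[:, :, :, 0], torch.sigmoid(E[:, 0]) * L)
out1.sum().backward()
g1 = E.grad.clone()

E.grad = None
out2 = torch.einsum('btrc,ri,ic->btic', x0, L, torch.sigmoid(E))
out2.sum().backward()
g2 = E.grad.clone()

print((g1 - g2).abs().max())  # expect only float-precision noise if the two are equivalent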
In case the small size of the differences above is not convincing enough, I am happy to share the 'derivation' of why this is the correct einsum expression.
What could be going on?