[float32 precision] Are torch.einsum and F.conv2d equivalent? Different results with torch.einsum and F.conv2d

I need to use torch.einsum to do the convolution instead of using torch.nn.functional.conv2d.

:books: In short, I have 16 images of size (3, 4, 4) and 16 kernels of size (3, 4, 4). I need to apply each kernel to its corresponding image and get one output (one data point) per image.

import torch
import torch.nn.functional as F

weight = torch.normal(10**-5, 10**-6, (16, 3, 4, 4))
bias = torch.normal(10**-5, 10**-6, (16,1))

img = torch.randn(16, 3, 4, 4)

I compared the convolutional operation with einsum and conv2d:


F.conv2d(img[0], weight[0].unsqueeze(0), bias[0]).item()

I got 4.617268132278696e-05

But with

(torch.einsum('chw,chw->', [img[0], weight[0]]) + bias[0]).item()

I got 4.617268859874457e-05.

All data dtype = torch.float32.

:question: I would like to know whether einsum and conv2d are equivalent in my scenario.

:bulb: The reason for implementing this with torch.einsum:

I have 16 images and 16 kernels and need to apply one kernel to one image to get one output. With einsum it is easy to get all 16 outputs at once:

torch.einsum('bchw,bchw->b', [img, weight])+bias.squeeze(-1)

The output:

tensor([ 4.6173e-05, -9.3411e-06, -8.0316e-05, -6.5993e-05,  1.3381e-04,
        -2.3025e-05, -1.3640e-06,  9.6504e-05,  2.1309e-06, -4.2717e-05,
         3.5023e-06,  3.2773e-05,  2.0304e-04, -2.4030e-05,  1.0894e-04,

Hi Luca!

For this computation, conv2d() and einsum() are equivalent. The
difference* you see is due to numerical round-off error (because the two
versions are performing the computations in different, but mathematically
equivalent orders).
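As a minimal sketch of the ordering effect (the seed and the two accumulation orders are illustrative assumptions, not what conv2d() or einsum() actually do internally), one can sum the same 48 float32 products left-to-right and right-to-left and compare both against a float64 reference. The two float32 sums may or may not differ in the last bits for a given seed, but both stay within a loose round-off bound of the reference:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, used only for reproducibility
img = torch.randn(3, 4, 4)
weight = torch.normal(1e-5, 1e-6, (3, 4, 4))

# The 48 elementwise products whose sum is the single output value.
prod = (img * weight).flatten()

# Order 1: left-to-right accumulation in float32.
fwd = torch.zeros((), dtype=torch.float32)
for p in prod:
    fwd = fwd + p

# Order 2: right-to-left accumulation in float32.
rev = torch.zeros((), dtype=torch.float32)
for p in prod.flip(0):
    rev = rev + p

# Reference: the same float32 products accumulated in float64.
ref = prod.double().sum()

# Loose worst-case bound on the round-off of summing n float32 terms.
eps = torch.finfo(torch.float32).eps
bound = prod.numel() * eps * prod.abs().sum().item()

print(f"forward   : {fwd.item():.17e}")
print(f"reverse   : {rev.item():.17e}")
print(f"reference : {ref.item():.17e}")
print(f"bound     : {bound:.3e}")
```
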

Your two results agree to seven significant digits, which is what you would
expect with single-precision (float32) accuracy.
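To put a number on that, here is a small sketch in plain Python that compares the relative difference between the two values quoted in the question with the float32 machine epsilon. The variable names are my own; the epsilon is written out as 2**-23, which is the value torch.finfo(torch.float32).eps reports. The relative difference comes out on the order of one epsilon:

```python
# The two values reported in the question.
conv_result = 4.617268132278696e-05
einsum_result = 4.617268859874457e-05

# float32 machine epsilon (same value as torch.finfo(torch.float32).eps).
eps32 = 2.0 ** -23

# Relative difference between the two results.
rel_diff = abs(conv_result - einsum_result) / abs(conv_result)

print(f"relative difference : {rel_diff:.3e}")
print(f"float32 epsilon     : {eps32:.3e}")
print(f"difference / epsilon: {rel_diff / eps32:.2f}")
```
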

einsum() is a perfectly fine way to perform your computation. Note,
however, that you can use conv2d() for this if you merge your "batch"
and channels dimensions together and use conv2d()'s groups feature:

>>> import torch
>>> print (torch.__version__)
>>> _ = torch.manual_seed(1)
>>> weight = torch.normal(10**-5, 10**-6, (16, 3, 4, 4))
>>> bias = torch.normal(10**-5, 10**-6, (16,1))
>>> img = torch.randn(16, 3, 4, 4)
>>> result = torch.einsum('bchw,bchw->b', [img, weight])+bias.squeeze(-1)
>>> resultB = torch.nn.functional.conv2d (img.reshape (48, 4, 4), weight, bias.squeeze(), groups = 16).squeeze()
>>> torch.allclose (result, resultB)      # equal up to round-off error
>>> torch.equal (result[0], resultB[0])   # but first elements happen to be equal
>>> resultB
tensor([ 4.6173e-05, -9.3411e-06, -8.0316e-05, -6.5993e-05,  1.3381e-04,
        -2.3025e-05, -1.3640e-06,  9.6504e-05,  2.1309e-06, -4.2717e-05,
         3.5023e-06,  3.2773e-05,  2.0304e-04, -2.4030e-05,  1.0894e-04,
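To spell out the shapes behind the groups approach, here is a hedged sketch wrapped in a helper. The name per_sample_conv is hypothetical (it is not a PyTorch API), and unlike the session above it adds an explicit batch dimension of size 1 instead of relying on unbatched input:

```python
import torch
import torch.nn.functional as F


def per_sample_conv(img, weight, bias):
    """Apply kernel i to image i (hypothetical helper, not a PyTorch API)."""
    b, c, h, w = img.shape
    # Fold the batch into the channels: one "image" with b*c channels.
    x = img.reshape(1, b * c, h, w)
    # groups=b splits the b*c input channels into b groups of c channels;
    # output channel i only sees group i, i.e. image i with kernel i.
    out = F.conv2d(x, weight, bias.reshape(b), groups=b)  # shape (1, b, 1, 1)
    return out.reshape(b)


torch.manual_seed(1)
weight = torch.normal(1e-5, 1e-6, (16, 3, 4, 4))
bias = torch.normal(1e-5, 1e-6, (16, 1))
img = torch.randn(16, 3, 4, 4)

via_conv = per_sample_conv(img, weight, bias)
via_einsum = torch.einsum('bchw,bchw->b', img, weight) + bias.squeeze(-1)

# Compare up to round-off error rather than bit-for-bit.
print(torch.allclose(via_conv, via_einsum))
```

The output collapses to one value per image only because the kernel has the same spatial size as the image; with a smaller kernel each group would produce a spatial map instead.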

*) Oddly enough, I get the same results for conv2d() and einsum() when
I run your code. This happens on a couple of different versions of PyTorch
(including 1.13.0) on both Windows and Linux.


K. Frank


Hi Frank,

Thank you so much for your vivid and clear explanation!

FYI, I am using macOS with an M1 chip with torch = 1.12.1.

Best wishes,