PyTorch gives different results when using `torch.matmul` vs. a `for`-loop to pass input through linear layers

Hi All,

This is my first post here, and I have been at this for about two days now. I am working on a model that takes an input x, passes it through several linear layers, and concatenates the results along a new leading dimension (one slot per linear layer), as follows:

import torch
import torch.nn as nn

num_heads = 4
n, c, h, w = 2, 32, 64, 64
x = torch.rand((n * c * h * w)).reshape(n, c, h, w)

x = x.view(n, c, h * w).transpose(1, 2)   # n, h*w, c

# One independent (frozen) linear projection per head
heads_q = []
for i in range(num_heads):
    l = nn.Linear(c, c // num_heads, bias=False)
    l.weight.requires_grad_(False)
    heads_q.append(l)

# Apply each head, stacking the results along a new leading dimension
# (the loop variable is named `head` to avoid shadowing the height `h`)
heads_q_res = torch.cat([
        head(x).unsqueeze(0) for head in heads_q
    ], dim=0).detach()  # num_heads, n, h*w, c//num_heads
heads_q_res = heads_q_res.transpose(2, 3)    # num_heads, n, c//num_heads, h*w

I want to get the same results as heads_q_res without having to use a for-loop to pass the input x through each of the linear layers. Thus, I have modified the above code as follows:

import torch
import torch.nn as nn

num_heads = 4
n, c, h, w = 2, 32, 64, 64
x = torch.rand((n * c * h * w)).reshape(n, c, h, w)

x = x.view(n, c, h * w).transpose(1, 2)   # n, h*w, c

heads_q = []
Wqs = []
for i in range(num_heads):
    l = nn.Linear(c, c // num_heads, bias=False)
    l.weight.requires_grad_(False)
    heads_q.append(l)
    Wqs.append(l.weight.unsqueeze(0).unsqueeze(0))
Wqs = torch.cat(Wqs, dim=0)  # num_heads, 1, c//num_heads, c
print(f'Wqs.shape: {Wqs.shape}')

# Sanity check: the stacked weights match the per-head weights exactly
for i in range(num_heads):
    assert torch.all(Wqs[i] == heads_q[i].weight.unsqueeze(0))
print('All weights check out.')

heads_q_res = torch.cat([
        head(x).unsqueeze(0) for head in heads_q
    ], dim=0).detach()  # num_heads, n, h*w, c//num_heads
heads_q_res = heads_q_res.transpose(2, 3)    # num_heads, n, c//num_heads, h*w

_stacked = torch.matmul(Wqs, x.transpose(1, 2))    # num_heads, n, c//num_heads, h*w
print(f'_stacked.shape:    {_stacked.shape}')
print(f'heads_q_res.shape: {heads_q_res.shape}')
assert (torch.allclose(heads_q_res, _stacked)), f'\n{torch.abs(heads_q_res-_stacked)}'

Here, _stacked represents passing the input x through all the linear layers at once. Unfortunately, the assertion in the last line of the code block above fails:

_stacked.shape:    torch.Size([4, 2, 8, 4096])
heads_q_res.shape: torch.Size([4, 2, 8, 4096])
Traceback (most recent call last):
  File "f.py", line 54, in <module>
    assert (torch.allclose(heads_q_res, _stacked)), f'\n{torch.abs(heads_q_res-_stacked)}'
AssertionError:
tensor([[[[0.0000e+00, 7.4506e-09, 5.9605e-08,  ..., 0.0000e+00,
           2.9802e-08, 0.0000e+00],
          [0.0000e+00, 2.9802e-08, 8.9407e-08,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.1176e-08, 1.4901e-08, 1.1921e-07,  ..., 3.7253e-09,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.9802e-08, 3.7253e-09, 1.4901e-08,  ..., 7.4506e-09,
           0.0000e+00, 1.4901e-08],
          [0.0000e+00, 5.9605e-08, 0.0000e+00,  ..., 5.9605e-08,
           0.0000e+00, 0.0000e+00],
          [2.9802e-08, 0.0000e+00, 5.9605e-08,  ..., 1.4901e-08,
           0.0000e+00, 0.0000e+00]],

         [[1.4901e-08, 2.2352e-08, 5.9605e-08,  ..., 4.4703e-08,
           2.9802e-08, 0.0000e+00],
          [2.6077e-08, 1.4901e-08, 2.9802e-08,  ..., 2.9802e-08,
           2.9802e-08, 1.4901e-08],
          [0.0000e+00, 2.9802e-08, 0.0000e+00,  ..., 1.4901e-08,
           2.9802e-08, 5.9605e-08],
          ...,
          [2.9802e-08, 4.4703e-08, 2.9802e-08,  ..., 1.4901e-08,
           4.4703e-08, 3.3528e-08],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.4901e-08, 1.4901e-08, 7.4506e-09,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]]],


        ...])

(output truncated; the remaining entries are likewise between 0 and ~1.2e-07)

I have also tried replacing

_stacked = torch.matmul(Wqs, x.transpose(1,2))    # num_heads, n, c//num_heads, h*w

with

_stacked = torch.matmul(x.unsqueeze(1), Wqs.squeeze(1).transpose(1,2)).permute(1, 0, 3, 2)  # num_heads, n, c//num_heads, h*w

but to no avail.

However, I do not get this error when I initialize x as x = torch.ones((n*c*h*w)).reshape(n,c,h,w). I am completely lost here and have no clue how to fix this error. I would like to pass all the assertions regardless of how x is initialized; any help would be appreciated.

Thank you.

You should do

_stacked = torch.matmul(x, Wqs.transpose(3, 2)).transpose(3, 2)

instead of what you have done. Please refer to torch.nn.functional.linear to perform the same operation for _stacked.
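For reference, a minimal sketch of that equivalence (the names and shapes below are hypothetical, chosen to match the post):

import torch
import torch.nn.functional as F

W = torch.rand(8, 32)         # one head's weight: (out_features, in_features)
x = torch.rand(2, 4096, 32)   # input: (n, h*w, c)

# F.linear(x, W) computes x @ W.T, which is what nn.Linear(bias=False) does
out_linear = F.linear(x, W)                       # (2, 4096, 8)
out_matmul = torch.matmul(x, W.transpose(0, 1))   # same math
print(torch.allclose(out_linear, out_matmul))     # use allclose; bitwise equality is not guaranteed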

Hi @yuri, using _stacked = torch.matmul(x, Wqs.transpose(3, 2)).transpose(3, 2) also throws an assertion error. I am beginning to wonder if it might be an issue with numerical precision, as the absolute values of the errors are in the range of 1e-7, and some are even an order of magnitude lower. Either that, or matmul might be doing something undocumented.

Please be aware that the two operations are not the same: the second one is equivalent to passing through a linear layer, while the one you originally posted is not.
Check (heads_q_res - _stacked): a discrepancy at ~1e-8 precision should be fine.

You can start with x = 10000 * x to better compare the two operations.
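For instance, a small sketch of what that scaling experiment shows (shapes are hypothetical, smaller than in the post):

import torch

W = torch.rand(8, 32)       # stands in for one head's weight
x = torch.rand(2, 32, 16)   # stands in for x.transpose(1, 2)

a = torch.matmul(W, x)                                      # W @ x
b = torch.matmul(x.transpose(1, 2), W.t()).transpose(1, 2)  # (x^T @ W^T)^T

xs = 10000 * x
a2 = torch.matmul(W, xs)
b2 = torch.matmul(xs.transpose(1, 2), W.t()).transpose(1, 2)

# The absolute error grows with the magnitude of x, while the relative
# error stays near float32 machine epsilon: classic rounding behaviour.
print((a - b).abs().max())      # typically around 1e-7
print((a2 - b2).abs().max())    # roughly 10000x larger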

In my understanding, both operations achieve the same result. No?

Assuming A = Wqs, B = x, C = _stacked:

The original code computes C = A.(B^T).

The change that @yuri suggested computes C = (B.(A^T))^T = A.(B^T), which is the same thing.
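A quick numeric check of that identity (a sketch with small, hypothetical shapes):

import torch

A = torch.rand(8, 32)    # stands in for one head of Wqs
B = torch.rand(16, 32)   # stands in for one batch item of x

C1 = A @ B.t()           # A.(B^T)
C2 = (B @ A.t()).t()     # (B.(A^T))^T
# Mathematically identical; numerically equal only up to float32 rounding
print(torch.allclose(C1, C2, atol=1e-6))  # True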

Just to add to the answer: by default, PyTorch creates float32 tensors, which carry roughly 7 significant decimal digits through operations.
If you would like more precision, convert the data to double (float64) before performing the operation.
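A sketch of that, reusing the names from the original post (assumes Wqs and x are still in scope):

# .double() casts to float64 (~16 significant decimal digits)
Wqs64 = Wqs.double()
x64 = x.double()

a = torch.matmul(Wqs64, x64.transpose(1, 2))
b = torch.matmul(x64, Wqs64.transpose(3, 2)).transpose(3, 2)
print((a - b).abs().max())   # should shrink from ~1e-7 to ~1e-16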

I don't think so, since this is a "partial" transpose over only the 3rd and 4th dimensions of the tensor. Check that

torch.matmul(Wqs, x.transpose(1,2)) - torch.matmul(x, Wqs.transpose(3, 2)).transpose(3, 2)

is not a zero matrix.

Probably. Check with a tolerance (atol in torch.allclose()), considering it is a floating-point operation, or print it and see?
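For example, reusing the names from the original post (the tolerance value is a suggestion, not canonical):

# The default tolerances (rtol=1e-5, atol=1e-8) fail on entries close to
# zero, where the rtol term vanishes and atol=1e-8 is tighter than the
# ~1e-7 float32 rounding error seen above.
print(torch.allclose(heads_q_res, _stacked, atol=1e-6))  # True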