Hi All,
This is my first post here; I have been stuck on this for about two days now. I am working on a model that takes an input x, passes it through several linear layers, and stacks the per-layer results along a new leading dimension whose size is the number of linear layers, as follows:
import torch
import torch.nn as nn

num_heads = 4
n, c, h, w = 2, 32, 64, 64
x = torch.rand((n * c * h * w)).reshape(n, c, h, w)
x = x.view(n, c, h * w).transpose(1, 2)  # n, h*w, c

heads_q = []
for i in range(num_heads):
    l = nn.Linear(c, c // num_heads, bias=False)
    l.weight.requires_grad_(False)
    heads_q.append(l)

heads_q_res = torch.cat([
    head(x).unsqueeze(0) for head in heads_q
], dim=0).detach()  # num_heads, n, h*w, c//num_heads
heads_q_res = heads_q_res.transpose(2, 3)  # num_heads, n, c//num_heads, h*w
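(As an aside, concatenating unsqueezed results as above is the same as stacking along a new leading dimension, so the two lines could equivalently be written with torch.stack:)

# equivalent to the cat([... .unsqueeze(0)], dim=0) above
heads_q_res = torch.stack([head(x) for head in heads_q], dim=0).detach()
heads_q_res = heads_q_res.transpose(2, 3)  # num_heads, n, c//num_heads, h*w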
I want to obtain the same result as heads_q_res without a for-loop that passes the input x through each linear layer one at a time. To that end, I have modified the code above as follows:
num_heads = 4
n, c, h, w = 2, 32, 64, 64
x = torch.rand((n * c * h * w)).reshape(n, c, h, w)
x = x.view(n, c, h * w).transpose(1, 2)  # n, h*w, c

heads_q = []
Wqs = []
for i in range(num_heads):
    l = nn.Linear(c, c // num_heads, bias=False)
    l.weight.requires_grad_(False)
    heads_q.append(l)
    Wqs.append(l.weight.unsqueeze(0).unsqueeze(0))
Wqs = torch.cat(Wqs, dim=0)  # num_heads, 1, c//num_heads, c
print(f'Wqs.shape: {Wqs.shape}')

for i in range(num_heads):
    assert torch.all(Wqs[i] == heads_q[i].weight.unsqueeze(0))
print('All weights check out.')

heads_q_res = torch.cat([
    head(x).unsqueeze(0) for head in heads_q
], dim=0).detach()  # num_heads, n, h*w, c//num_heads
heads_q_res = heads_q_res.transpose(2, 3)  # num_heads, n, c//num_heads, h*w

_stacked = torch.matmul(Wqs, x.transpose(1, 2))  # num_heads, n, c//num_heads, h*w
assert torch.allclose(heads_q_res, _stacked), f'\n{torch.abs(heads_q_res - _stacked)}'
Here, _stacked represents passing the input x through all of the linear layers at once. Unfortunately, the assertion in the last line of the code block above fails:
_stacked.shape: torch.Size([4, 2, 8, 4096])
heads_q_res.shape: torch.Size([4, 2, 8, 4096])
Traceback (most recent call last):
  File "f.py", line 54, in <module>
    assert torch.allclose(heads_q_res, _stacked), f'\n{torch.abs(heads_q_res - _stacked)}'
AssertionError:
tensor([[[[0.0000e+00, 7.4506e-09, 5.9605e-08, ..., 0.0000e+00,
2.9802e-08, 0.0000e+00],
[0.0000e+00, 2.9802e-08, 8.9407e-08, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[1.1176e-08, 1.4901e-08, 1.1921e-07, ..., 3.7253e-09,
0.0000e+00, 0.0000e+00],
...,
[2.9802e-08, 3.7253e-09, 1.4901e-08, ..., 7.4506e-09,
0.0000e+00, 1.4901e-08],
[0.0000e+00, 5.9605e-08, 0.0000e+00, ..., 5.9605e-08,
0.0000e+00, 0.0000e+00],
[2.9802e-08, 0.0000e+00, 5.9605e-08, ..., 1.4901e-08,
0.0000e+00, 0.0000e+00]],
[[1.4901e-08, 2.2352e-08, 5.9605e-08, ..., 4.4703e-08,
2.9802e-08, 0.0000e+00],
[2.6077e-08, 1.4901e-08, 2.9802e-08, ..., 2.9802e-08,
2.9802e-08, 1.4901e-08],
[0.0000e+00, 2.9802e-08, 0.0000e+00, ..., 1.4901e-08,
2.9802e-08, 5.9605e-08],
...,
[2.9802e-08, 4.4703e-08, 2.9802e-08, ..., 1.4901e-08,
4.4703e-08, 3.3528e-08],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[1.4901e-08, 1.4901e-08, 7.4506e-09, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]],
[[[1.3039e-08, 2.9802e-08, 2.9802e-08, ..., 3.7253e-08,
1.4901e-08, 2.2352e-08],
[0.0000e+00, 8.9407e-08, 0.0000e+00, ..., 3.1665e-08,
6.7055e-08, 2.9802e-08],
[0.0000e+00, 2.9802e-08, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 2.9802e-08],
...,
[2.9802e-08, 0.0000e+00, 1.1921e-07, ..., 5.9605e-08,
2.9802e-08, 5.9605e-08],
[0.0000e+00, 4.4703e-08, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[5.9605e-08, 2.9802e-08, 1.4901e-08, ..., 5.9605e-08,
0.0000e+00, 0.0000e+00]],
[[3.7253e-09, 7.4506e-09, 5.9605e-08, ..., 3.6554e-08,
5.2154e-08, 0.0000e+00],
[4.4703e-08, 1.4901e-08, 0.0000e+00, ..., 2.9802e-08,
1.4901e-08, 2.9802e-08],
[2.2352e-08, 7.4506e-09, 0.0000e+00, ..., 2.9802e-08,
2.9802e-08, 5.2154e-08],
...,
[0.0000e+00, 2.9802e-08, 2.9802e-08, ..., 5.9605e-08,
5.9605e-08, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 5.9605e-08, ..., 5.9605e-08,
0.0000e+00, 0.0000e+00],
[5.9605e-08, 0.0000e+00, 2.9802e-08, ..., 5.9605e-08,
0.0000e+00, 0.0000e+00]]],
[[[0.0000e+00, 1.4901e-08, 0.0000e+00, ..., 0.0000e+00,
8.9407e-08, 2.9802e-08],
[0.0000e+00, 5.9605e-08, 0.0000e+00, ..., 0.0000e+00,
2.9802e-08, 5.9605e-08],
[2.9802e-08, 2.9802e-08, 1.1921e-07, ..., 5.9605e-08,
0.0000e+00, 5.9605e-08],
...,
[1.4901e-08, 1.4901e-08, 0.0000e+00, ..., 2.9802e-08,
7.4506e-09, 1.4901e-08],
[0.0000e+00, 2.2352e-08, 0.0000e+00, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00],
[2.9802e-08, 0.0000e+00, 1.4901e-08, ..., 1.4901e-08,
0.0000e+00, 0.0000e+00]],
[[0.0000e+00, 5.2154e-08, 0.0000e+00, ..., 8.9407e-08,
1.1921e-07, 5.9605e-08],
[2.2352e-08, 0.0000e+00, 1.4901e-08, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 5.9605e-08, 5.9605e-08, ..., 0.0000e+00,
5.9605e-08, 0.0000e+00],
...,
[1.4901e-08, 4.4703e-08, 1.1176e-08, ..., 1.4901e-08,
1.4901e-08, 0.0000e+00],
[2.9802e-08, 0.0000e+00, 0.0000e+00, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00],
[2.9802e-08, 1.4901e-08, 1.4901e-08, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]],
[[[7.4506e-09, 5.9605e-08, 2.2352e-08, ..., 1.1176e-08,
4.4703e-08, 1.4901e-08],
[0.0000e+00, 2.9802e-08, 0.0000e+00, ..., 0.0000e+00,
5.2154e-08, 4.4703e-08],
[2.9802e-08, 4.4703e-08, 5.9605e-08, ..., 0.0000e+00,
2.9802e-08, 5.9605e-08],
...,
[2.9802e-08, 0.0000e+00, 4.0978e-08, ..., 2.9802e-08,
1.4901e-08, 2.9802e-08],
[2.9802e-08, 0.0000e+00, 0.0000e+00, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00],
[2.9802e-08, 0.0000e+00, 5.9605e-08, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00]],
[[2.9802e-08, 7.4506e-09, 7.4506e-08, ..., 4.4703e-08,
5.9605e-08, 4.4703e-08],
[2.9802e-08, 0.0000e+00, 1.4901e-08, ..., 0.0000e+00,
1.4901e-08, 5.9605e-08],
[5.9605e-08, 2.9802e-08, 2.9802e-08, ..., 3.6322e-08,
0.0000e+00, 2.9802e-08],
...,
[4.4703e-08, 7.4506e-09, 1.8626e-08, ..., 1.4901e-08,
2.6077e-08, 2.2352e-08],
[0.0000e+00, 1.4901e-08, 5.9605e-08, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[2.9802e-08, 2.9802e-08, 2.9802e-08, ..., 2.9802e-08,
0.0000e+00, 0.0000e+00]]]])
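For reference, the largest entries in the difference tensor above are around 1.2e-07, which is on the order of float32 machine epsilon (~1.19e-07), while torch.allclose defaults to rtol=1e-05 and atol=1e-08. A minimal diagnostic sketch to quantify the mismatch (assuming heads_q_res and _stacked from the code above):

# largest element-wise discrepancy between the two computations
print(f'max abs diff: {torch.max(torch.abs(heads_q_res - _stacked)).item():.3e}')
# the same comparison with an absolute tolerance above float32 epsilon
print(torch.allclose(heads_q_res, _stacked, atol=1e-6))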
I have also tried replacing
_stacked = torch.matmul(Wqs, x.transpose(1, 2))  # num_heads, n, c//num_heads, h*w
with
_stacked = torch.matmul(x.unsqueeze(1), Wqs.squeeze(1).transpose(1, 2)).permute(1, 0, 3, 2)  # num_heads, n, c//num_heads, h*w
but to no avail.
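(For completeness, the same batched product can also be written as a single einsum; a sketch, assuming Wqs has shape (num_heads, 1, c//num_heads, c) as above:)

# h: heads, o: c//num_heads output channels, c: input channels, m: h*w positions
_stacked = torch.einsum('hoc,nmc->hnom', Wqs.squeeze(1), x)  # num_heads, n, c//num_heads, h*w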
However, I do not get this error when I initialize x as x = torch.ones((n * c * h * w)).reshape(n, c, h, w). I am completely lost here and have no clue how to fix this error. I would like all the assertions to pass regardless of how x is initialized; any help would be appreciated.
Thank you.