Hi, I’m using PyTorch 1.7.1 and I observed the following:
In [9]: import torch
...: q = torch.randn(1, 1000, 128)
...: k = torch.randn(1, 1000, 128)
...: q1 = q[:, :500, :]
...: k1 = k[:, :500, :]
...: attn = torch.bmm(q, k.transpose(1,2))
...: attn1 = torch.bmm(q1, k1.transpose(1,2))
...:
...: A = attn[0,:500,:500]
...: B = attn1[0,:500,:500]
...:
...: print(A)
...: print(B)
...: print((A==B).all())
...: print(A.sum() == B.sum())
tensor([[ 13.7018, -6.4425, -10.8135, ..., -22.0151, -7.4277, 1.3627],
[ -1.8352, 1.0710, -7.1797, ..., -6.5084, -7.4944, 4.5227],
[ 9.8117, -20.0964, -1.4925, ..., 5.0464, -4.3286, 5.9336],
...,
[ -3.9975, -3.4578, 2.1641, ..., 0.7644, 9.4954, 3.1092],
[-17.6577, 17.8311, 10.5954, ..., -0.4581, 10.4234, 1.1901],
[-11.4489, 15.5743, 9.7824, ..., 0.8799, 16.4985, 20.6857]])
tensor([[ 13.7018, -6.4425, -10.8135, ..., -22.0151, -7.4277, 1.3627],
[ -1.8352, 1.0710, -7.1797, ..., -6.5084, -7.4944, 4.5227],
[ 9.8117, -20.0964, -1.4925, ..., 5.0464, -4.3286, 5.9336],
...,
[ -3.9975, -3.4578, 2.1641, ..., 0.7644, 9.4954, 3.1092],
[-17.6577, 17.8311, 10.5954, ..., -0.4581, 10.4234, 1.1901],
[-11.4489, 15.5743, 9.7824, ..., 0.8799, 16.4985, 20.6857]])
tensor(True)
tensor(False)
Can someone help explain the discrepancy? A and B compare equal element by element, yet A.sum() and B.sum() are not equal.
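One guess I have not confirmed: A is a slice (a view) of the 1000x1000 result while B is its own 500x500 tensor, so the two sum() calls may add the same values in a different order, and floating-point addition is not associative. The sketch below is plain Python rather than PyTorch and only illustrates that general effect. The helper name sum_in_order and the three sample values are made up for illustration; it says nothing about how PyTorch actually orders its reductions.

```python
import math


def sum_in_order(values):
    """Add the values one at a time, in exactly the order given."""
    total = 0.0
    for v in values:
        total += v
    return total


# Hypothetical sample data: the same three numbers, summed in two orders.
values = [0.1, 0.2, 0.3]

left_to_right = sum_in_order(values)            # (0.1 + 0.2) + 0.3
right_to_left = sum_in_order(reversed(values))  # (0.3 + 0.2) + 0.1

# The inputs are identical, but the rounded totals are not.
print(left_to_right == right_to_left)
# A tolerance-based comparison treats the two totals as equal.
print(math.isclose(left_to_right, right_to_left))
```

If this is what is happening, I assume the sums should be compared with a tolerance (for example torch.allclose, with tolerances suited to float32) rather than with ==.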
Many thanks.