Can someone explain why this happens in the code below?
As is, with slen=24 and dim=8, I get out == out2.
If I change slen to 25 or more, or dim to 9 or more, it no longer works: out and out2 are not equal.
I can't figure out why.
import torch
import time
def tile(x: torch.Tensor, count: int) -> torch.Tensor:
    """Tile x on dimension 0 count times.

    Every slice along dimension 0 is repeated ``count`` times back to back,
    so the result is ordered ``[x[0], x[0], ..., x[1], x[1], ...]`` and its
    first dimension is ``x.size(0) * count``; all other dimensions are
    unchanged.

    Args:
        x: Input tensor. It must be viewable as ``(x.size(0), -1)``.
        count: Number of consecutive copies of each slice.

    Returns:
        A new contiguous tensor holding the tiled data.
    """
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    # Flatten each slice to a row, transpose so the batch is the last axis,
    # repeat along the flattened axis, then transpose back: every row now
    # holds `count` copies of its slice side by side, and the final view
    # splits those copies into consecutive entries of dimension 0.
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    return x
# Problem sizes for the comparison below.
batch_size = 16
beam_size = 2
slen = 24
dim = 8
enc_out = torch.randn(batch_size, slen, dim)
dec_in = torch.randn(batch_size * beam_size, 1, dim)
# Each batch entry of enc_out is repeated beam_size times in a row.
tiled_enc = tile(enc_out, beam_size)

# regular matmul: one (1, dim) x (dim, slen) product per beam entry,
# giving a (batch_size * beam_size, 1, slen) result.
out = torch.matmul(dec_in, tiled_enc.transpose(2, 1))

# beamed matmul: group the beam_size queries of each batch entry and
# multiply them against a single copy of that entry's encoder output.
inp1 = dec_in[:, 0, :].unfold(0, beam_size, beam_size).transpose(2, 1)
# The beam_size copies inside each window are identical, so keep the first.
inp2 = tiled_enc.unfold(0, beam_size, beam_size)[:, :, :, 0]
out2 = torch.matmul(inp1, inp2.transpose(2, 1))
out2 = out2.view(batch_size * beam_size, 1, -1)

# Both paths compute the same dot products mathematically, but they are
# batched differently, so the floating-point rounding of each element can
# differ (presumably because differently shaped operands take different
# kernel / accumulation paths once the sizes grow). torch.equal demands
# bit-identical values and therefore reports False on such rounding noise;
# compare within a tolerance instead. An explicit atol is used because the
# default absolute tolerance is too tight for float32 results near zero.
print(torch.allclose(out, out2, atol=1e-6))