Weird matmul() behavior

Can someone explain the behavior in the code below?
As is, with slen=24 and dim=8, I get out == out2.

If I change slen to 25 or more, or dim to 9 or more, then out and out2 are no longer equal.
I can't figure out why.

import torch

def tile(x, count):
    """
    Tiles x on dimension 0 count times.
    """
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    return x
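
# For reference: tile repeats each batch element count times consecutively,
# e.g. (illustrative values, not from the original post):
#   tile(torch.tensor([[1.], [2.]]), 2)  ->  tensor([[1.], [1.], [2.], [2.]])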

batch_size = 16
beam_size = 2

slen = 24
dim = 8

enc_out = torch.randn(batch_size, slen, dim)
dec_in = torch.randn(batch_size * beam_size, 1, dim)

tiled_enc = tile(enc_out, beam_size)

# regular matmul

out = torch.matmul(dec_in, tiled_enc.transpose(2, 1))

# beamed matmul

inp1 = dec_in[:, 0, :].unfold(0, beam_size, beam_size).transpose(2, 1)
inp2 = tiled_enc.unfold(0, beam_size, beam_size)[:, :, :, 0]
out2 = torch.matmul(inp1, inp2.transpose(2, 1))
out2 = out2.view(batch_size*beam_size, 1, -1)

print(torch.equal(out, out2))

Using torch.equal with floating point numbers is not a good idea, due to the limited floating point precision and the small errors expected from a different order of operations.
If I change slen to 25 and dim to 9, I see an .abs().max() error of tensor(9.5367e-07), which is expected for float32.
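
A tolerance-based comparison is the usual alternative; as a minimal sketch (the atol value is just an illustrative choice, and out/out2 are the tensors from the code above):

print(torch.allclose(out, out2, atol=1e-5))  # compare with a tolerance instead of exact equality
print((out - out2).abs().max())              # or inspect the largest elementwise difference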

With real-life settings (batch 16, beam 5, slen=61, dim=512) I get 1.5e-5, which is still OK.

Many thanks.