Difference between `matmul` broadcasting and `bmm` in the computational graph

Hello,

I’m performing a matrix multiplication using the `matmul` function:

import torch

hidden_size = 8
batch_size = 5

# trainable weight matrix
W = torch.randn(hidden_size, hidden_size, requires_grad=True)
emb = torch.randn(batch_size, 12, hidden_size)

res = emb.matmul(W).bmm(emb.transpose(1, 2))

The first `matmul` call simply broadcasts the operation over the batch dimension, and the result is as expected.

Another way of accomplishing this is to use `bmm`:

Wb = W.expand(batch_size, -1, -1)
resb = emb.bmm(Wb).bmm(emb.transpose(1, 2))

Both results are equal (`torch.norm(resb - res)` is essentially zero), but while training my model, the loss takes different values depending on whether I use the first or the second method.
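
For example, a quick check along these lines (a sketch reusing the tensors above):

print(torch.allclose(res, resb))  # True
print((res - resb).abs().max())   # zero or a tiny float32 rounding error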

I’m guessing that the two methods generate different computational graphs.
Is that the case?
If so, what is the correct way of doing this?

Thank you in advance,
Ricardo

Hi,

These two should be equivalent even if they define different computational graphs (potentially doing broadcasting in a different way).
That means that, during training, because of floating-point precision, they can end up giving noticeably different results, as the small errors are amplified over the course of training.
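
As an illustration, here is a minimal sketch (reusing the shapes from the question) showing that the gradients from the two formulations agree only up to float32 rounding; that residual difference is what training amplifies:

import torch

torch.manual_seed(0)
hidden_size, batch_size, seq_len = 8, 5, 12
emb = torch.randn(batch_size, seq_len, hidden_size)

# two copies of the same weight, one per formulation
W1 = torch.randn(hidden_size, hidden_size, requires_grad=True)
W2 = W1.detach().clone().requires_grad_(True)

# method 1: broadcasting matmul; method 2: expand + bmm
emb.matmul(W1).bmm(emb.transpose(1, 2)).sum().backward()
emb.bmm(W2.expand(batch_size, -1, -1)).bmm(emb.transpose(1, 2)).sum().backward()

# any residual difference here is what gets amplified over many optimizer steps
print((W1.grad - W2.grad).abs().max())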

Do you have a small code sample to reproduce the behavior you present?


Hello @albanD!

Thank you for your quick reply.

Unfortunately, I do not have a simple code sample showing the behavior; I only have a working version of a more developed model.
I will try to make one and post it here as soon as possible.

Meanwhile, in your opinion, which function should normally be used in this case?

Best regards,
Ricardo

Both are correct. You can think of switching from one to the other as having the same effect as changing the random seed you set at the beginning of your script: all the numbers you get will be different, but if your model is robust, both runs should converge to a similar solution in terms of performance.
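
To make the analogy concrete, here is a hypothetical toy run (the regression problem and hyperparameters are made up for illustration): only the initialization seed differs between the two runs, so the per-step losses differ, yet both converge to a similar final loss.

import torch

# fixed toy regression problem
torch.manual_seed(0)
x = torch.randn(100, 8)
y = x @ torch.ones(8, 1)

def train(seed):
    torch.manual_seed(seed)  # only the weight initialization differs between runs
    w = torch.randn(8, 1, requires_grad=True)
    opt = torch.optim.SGD([w], lr=0.05)
    for _ in range(200):
        loss = ((x @ w - y) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()

# the loss trajectories differ, but the final losses are similar
print(train(1), train(2))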


What is the reason behind having both `matmul` and `bmm` when they can potentially do the exact same task? How exactly do they differ from one another? Is there any advantage to using one over the other?


`bmm` is the simple batched matrix-matrix multiply.
`matmul` is more general: depending on the inputs, it can correspond to `dot`, `mm`, or `bmm`.
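
For instance, a quick sketch of `matmul`'s shape-dependent dispatch:

import torch

a, b = torch.randn(4), torch.randn(4)
print(torch.matmul(a, b).shape)  # scalar result: behaves like dot

a, b = torch.randn(3, 4), torch.randn(4, 5)
print(torch.matmul(a, b).shape)  # (3, 5): behaves like mm

a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
print(torch.matmul(a, b).shape)  # (2, 3, 5): behaves like bmm

a, b = torch.randn(2, 3, 4), torch.randn(4, 5)
print(torch.matmul(a, b).shape)  # (2, 3, 5): broadcasts, which bmm cannot do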


Hi! Does `matmul` have higher floating-point precision than `bmm`? Thanks!

Hi,

They both have the same precision, which depends on the dtype of the inputs you give them.
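
For example, here is a sketch comparing the same product computed in float32 against a float64 reference, to show that the dtype is what drives the error:

import torch

torch.manual_seed(0)
a = torch.randn(64, 64, dtype=torch.float64)
b = torch.randn(64, 64, dtype=torch.float64)

exact = a @ b                              # float64 reference
approx = (a.float() @ b.float()).double()  # same product computed in float32

# the error scale is set by the dtype, not by matmul vs. bmm
print((exact - approx).abs().max())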

OK, thanks for the response.