What's the difference between batched matrix multiplication and multiplying each matrix in the batch separately?

As the title says, I want to know what the difference is between batched matrix multiplication and multiplying each matrix in the batch separately.

In an intermediate step of my network, I get a tensor x with shape [B, N, C] and a tensor y with shape [B, C, N].

The first way:

masked_x = x[mask].view(B, -1, C)
f = torch.matmul(masked_x, y)

The second way:

masked_x = [x[i][mask[i]] for i in range(B)]
f = [torch.matmul(masked_x[i], y[i]) for i in range(B)]

mask has shape [B, N] and holds bool values. Each mask[i] has the same number of True values.
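For concreteness, here is a minimal, self-contained sketch of this setup (the sizes B, N, C and the per-row count K are made up for illustration, and the tensors are placed on a CUDA device since the question is about GPU memory):

import torch

B, N, C, K = 4, 16, 8, 5  # made-up sizes; K = number of True values per row
device = "cuda"
x = torch.randn(B, N, C, device=device, requires_grad=True)
y = torch.randn(B, C, N, device=device, requires_grad=True)

# Build a [B, N] bool mask with exactly K True entries in every row
idx = torch.rand(B, N, device=device).topk(K, dim=1).indices
mask = torch.zeros(B, N, dtype=torch.bool, device=device)
mask.scatter_(1, idx, True)

With this mask, x[mask].view(B, -1, C) has shape [B, K, C], so the batched matmul with y produces [B, K, N].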

Then, when I trained the network, the first way took up more GPU memory than the second way during loss.backward(). What's the reason?

The first method does all B matmuls at once, whereas the second method does one at a time via the list comprehension. The second method will be slower than the first, but it requires less memory: the temporary buffers autograd allocates during backward are created per sample rather than for the whole batch at once, so the peak allocation is lower.
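If you want to measure this yourself, a rough sketch (assuming the x, y, mask from the setup above) is to compare the allocator's peak around a forward + backward pass of each variant:

def peak_mem(make_loss):
    # Clear old grads, reset the peak counter, run forward + backward,
    # and report the peak memory allocated in between.
    x.grad = None
    y.grad = None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    make_loss().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

def batched_loss():
    masked_x = x[mask].view(B, -1, C)
    return torch.matmul(masked_x, y).sum()

def looped_loss():
    outs = [torch.matmul(x[i][mask[i]], y[i]) for i in range(B)]
    return torch.stack(outs).sum()

print("batched:", peak_mem(batched_loss))
print("looped: ", peak_mem(looped_loss))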

Make sure your mask has requires_grad set to False; that will free up some additional memory if you're trying to minimize the memory footprint.

Thanks for the reply! So the second method is just trading time for space, and the two methods will give exactly the same result?

They'll give the same behaviour, but you should double-check by comparing the two tensors via torch.allclose.
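For example (again using the x, y, mask from the sketch above; torch.stack turns the list from the second method into a single [B, K, N] tensor so the shapes line up):

f_batched = torch.matmul(x[mask].view(B, -1, C), y)
f_looped = torch.stack([torch.matmul(x[i][mask[i]], y[i]) for i in range(B)])
print(torch.allclose(f_batched, f_looped))   # expect True
print((f_batched - f_looped).abs().max())    # worst-case elementwise difference

Bit-exact equality isn't guaranteed, since the batched and per-matrix paths may dispatch to different kernels, which is why allclose with its default tolerances is the right check.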

I’ll check that. Thanks a lot for your answers!
