As the title says, I want to know the difference between a batched matrix multiplication and multiplying each matrix in the batch separately.
In an intermediate step of my network, I get a tensor x with shape [B, N, C] and a tensor y with shape [B, C, N].
The first way:
masked_x = x[mask].view(B, -1, C)
f = torch.matmul(masked_x, y)
The second way:
masked_x = [x[i][mask[i]] for i in range(B)]
f = [torch.matmul(masked_x[i], y[i]) for i in range(B)]
mask has shape [B, N] and holds bool values; each mask[i] contains the same number of True entries, so the view in the first way is valid.
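For reference, here is a minimal self-contained sketch of the two ways (the dimensions B, N, C and the per-row count K are made-up values just for illustration); both produce the same result:

import torch

B, N, C = 4, 16, 8
K = 10  # assumed number of True entries in each row of the mask

x = torch.randn(B, N, C, requires_grad=True)
y = torch.randn(B, C, N, requires_grad=True)

# Boolean mask of shape [B, N] with exactly K True values per row.
mask = torch.zeros(B, N, dtype=torch.bool)
for i in range(B):
    mask[i, torch.randperm(N)[:K]] = True

# First way: one batched matmul on the masked, reshaped tensor.
masked_x = x[mask].view(B, -1, C)          # [B, K, C]
f_batched = torch.matmul(masked_x, y)      # [B, K, N]

# Second way: mask and multiply each batch element separately.
masked_x_list = [x[i][mask[i]] for i in range(B)]                  # each [K, C]
f_list = [torch.matmul(masked_x_list[i], y[i]) for i in range(B)]  # each [K, N]

print(torch.allclose(f_batched, torch.stack(f_list)))  # True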
Then, when I trained the network, the first way used more GPU memory than the second way during loss.backward(). What is the reason?
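For what it's worth, this is roughly how I compare the peak memory of the two variants (assuming the tensors above are moved to a CUDA device, and running each variant in a separate script so the graphs don't overlap):

torch.cuda.reset_peak_memory_stats()
loss = f_batched.sum()   # or torch.stack(f_list).sum() for the second way
loss.backward()
print(torch.cuda.max_memory_allocated() / 1024**2, "MiB")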