How to multiply a (batch_size, 1, y) tensor by a (y, x) tensor to get (batch_size, 1, x)?
The second tensor is not batched; it has only 2 dimensions. I want to multiply each item in the first tensor's batch by the second tensor.
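torch.bmm(a, b.expand(batch_size, y, x))
torch.bmm needs two 3-D tensors with the same batch size, so b gets a leading batch dimension via expand(), which returns a view without copying. The target shape must keep b's own (y, x) dimensions; expanding to (batch_size, 1, x) fails because b's size y can't be expanded to 1.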
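A minimal runnable sketch of the answer above. The sizes batch_size=4, y=3, x=5 are arbitrary values chosen only for illustration. It also shows torch.matmul (the @ operator), which broadcasts a 2-D second argument across the batch, so no explicit expand() is needed; both approaches should give the same result.

```python
import torch

# Illustrative sizes only (not from the original question)
batch_size, y, x = 4, 3, 5

a = torch.randn(batch_size, 1, y)  # batched input: (batch_size, 1, y)
b = torch.randn(y, x)              # one matrix shared by every batch item: (y, x)

# Option 1: bmm requires 3-D inputs with matching batch size,
# so expand b to (batch_size, y, x). expand() creates a view, not a copy.
out_bmm = torch.bmm(a, b.expand(batch_size, y, x))

# Option 2: matmul broadcasts the 2-D b over a's batch dimension.
out_matmul = a @ b

print(out_bmm.shape)  # torch.Size([4, 1, 5])
```

torch.matmul is usually the simpler choice here since it handles the broadcasting itself; torch.bmm is stricter about shapes, which can help catch mistakes.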