I am trying to multiply a 2-dimensional matrix with a 4-dimensional tensor, but torch.bmm doesn’t seem to accept 4-dimensional tensors.
This is the code I am using:
    import torch

    batch_size = 40
    max_len = 96

    logits = torch.randn(batch_size, 3)
    outs = torch.randn(batch_size, 3, max_len, 2)

    # Shape of desired output: (batch_size, max_len, 2)
    result = torch.bmm(logits, outs)
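If the intent is for each of the 3 logits to weight the matching slice of outs, the operation is a contraction over the size-3 dimension: result[b, l, k] = sum over c of logits[b, c] * outs[b, c, l, k]. A minimal sketch under that assumption, using torch.einsum, with a plain broadcasting version as a cross-check:

```python
import torch

batch_size = 40
max_len = 96

logits = torch.randn(batch_size, 3)
outs = torch.randn(batch_size, 3, max_len, 2)

# Assumed semantics: result[b, l, k] = sum_c logits[b, c] * outs[b, c, l, k]
result = torch.einsum("bc,bclk->blk", logits, outs)
print(result.shape)  # torch.Size([40, 96, 2])

# Same computation via broadcasting: give logits two trailing singleton dims,
# multiply elementwise, then sum out the size-3 dimension
result_bc = (logits[:, :, None, None] * outs).sum(dim=1)
assert torch.allclose(result, result_bc, atol=1e-5)
```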
The error is:
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
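torch.bmm only works on a pair of 3-dimensional tensors shaped (b, n, m) and (b, m, p). Assuming the same weighted-sum semantics as above, one way to stay with bmm is to reshape both operands into that form and restore the trailing (max_len, 2) shape afterwards:

```python
import torch

batch_size = 40
max_len = 96

logits = torch.randn(batch_size, 3)
outs = torch.randn(batch_size, 3, max_len, 2)

# logits: (b, 3) -> (b, 1, 3), so each batch item is a 1x3 row vector
lhs = logits.unsqueeze(1)
# outs: (b, 3, max_len, 2) -> (b, 3, max_len * 2), merging the last two dims
rhs = outs.reshape(batch_size, 3, max_len * 2)

# (b, 1, 3) @ (b, 3, max_len * 2) -> (b, 1, max_len * 2)
result = torch.bmm(lhs, rhs)
# Split the merged dim back out: (b, max_len, 2)
result = result.view(batch_size, max_len, 2)
```

Flattening the last two dimensions is safe here because outs is contiguous, so index l * 2 + k of the merged dimension corresponds to outs[b, c, l, k].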
Is there any way to do this?
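Unlike torch.bmm, torch.matmul accepts inputs with more than 3 dimensions and broadcasts the leading (batch) dimensions. A sketch, again assuming the weighted sum over the size-3 dimension, that moves that dimension next to the last one so each (1, 3) @ (3, 2) product is batched over (b, max_len):

```python
import torch

batch_size = 40
max_len = 96

logits = torch.randn(batch_size, 3)
outs = torch.randn(batch_size, 3, max_len, 2)

# logits: (b, 3) -> (b, 1, 1, 3); the middle 1 broadcasts over max_len
lhs = logits[:, None, None, :]
# outs: (b, 3, max_len, 2) -> (b, max_len, 3, 2)
rhs = outs.permute(0, 2, 1, 3)

# Batch dims (b, 1) and (b, max_len) broadcast to (b, max_len);
# each matrix product is (1, 3) @ (3, 2) -> (1, 2)
result = torch.matmul(lhs, rhs)  # (b, max_len, 1, 2)
result = result.squeeze(2)       # (b, max_len, 2)
```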