Hi,

u = torch.randn(3, 1, 40, 64)

v = torch.randn(1, 3, 40, 64)

How can I compute the product of u and v? torch.bmm does not work for this example.

Thanks,

What do you mean by "product" here? Matrix-matrix multiplication? Or broadcasting plus element-wise multiplication?

What output size do you expect? [3, 40, 64]? [1, 40, 64]? [3, 3, 40]?
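If "product" means broadcasting followed by element-wise multiplication, a minimal sketch is the plain `*` operator. Size-1 dimensions are broadcast, so a (3, 1, 40, 64) tensor times a (1, 3, 40, 64) tensor gives (3, 3, 40, 64):

```python
import torch

u = torch.randn(3, 1, 40, 64)
v = torch.randn(1, 3, 40, 64)

# Broadcasting: dim 1 of u (size 1) and dim 0 of v (size 1) are stretched to 3,
# then the two tensors are multiplied element-wise.
elementwise = u * v
print(elementwise.shape)  # torch.Size([3, 3, 40, 64])
```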

Let me guess…

Do you mean

```
u -> (A, B, batch_1, batch_2)
v -> (B, A, batch_1, batch_2)
result -> (A, A, batch_1, batch_2)
```

where `result[:, :, i, j]` equals `u[:, :, i, j] @ v[:, :, i, j]`?

Just like a 4-dim bmm?
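One way to express this "4-dim bmm" directly, sketched here with the shapes from the question, is `torch.einsum`. An alternative is `torch.matmul` after moving the two batch dimensions to the front, since `matmul` treats all leading dimensions as batch dimensions. The size names `A`, `B`, `batch_1`, `batch_2` simply follow the notation above.

```python
import torch

A, B, batch_1, batch_2 = 3, 1, 40, 64
u = torch.randn(A, B, batch_1, batch_2)
v = torch.randn(B, A, batch_1, batch_2)

# result[a, c, i, j] = sum over b of u[a, b, i, j] * v[b, c, i, j]
result_einsum = torch.einsum("abij,bcij->acij", u, v)

# Same thing with matmul: put (i, j) first so they act as batch dims,
# multiply the trailing (A, B) @ (B, A) matrices, then move (i, j) back to the end.
result_matmul = torch.matmul(u.permute(2, 3, 0, 1), v.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)

print(result_einsum.shape)  # torch.Size([3, 3, 40, 64])
print(torch.allclose(result_einsum, result_matmul, atol=1e-5))
```

Because B = 1 in the question, each (3 x 1) @ (1 x 3) product is an outer product, so for these particular shapes the result also equals the broadcast product `u * v`.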

Hi @Eta_C,

Exactly, the result will be:

(A, A, batch_1, batch_2)

What do you mean by i and j?

If I understand correctly…

Try this; you should build a small example to test its correctness.

```
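import torch
u = torch.randn(3, 1, 40, 64)
v = torch.randn(1, 3, 40, 64)
# Flatten the two batch dims (40, 64) into a single dim of size 2560
_u = u.reshape(3, 1, -1)
_v = v.reshape(1, 3, -1)
# Move the batch dim to the front: (2560, 3, 1) @ (2560, 1, 3) -> (2560, 3, 3)
_result = torch.bmm(_u.permute(2, 0, 1), _v.permute(2, 0, 1))
# Move the batch dim back to the end and unflatten it -> (3, 3, 40, 64)
result = _result.permute(1, 2, 0).reshape(3, 3, 40, 64)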
```
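Following the suggestion to test on a small example, here is a sketch that compares the reshape/bmm approach with an explicit loop over every (i, j) position. The sizes are arbitrary; B = 2 is chosen (an assumption, not from the question) so that a real sum over the inner dimension is exercised.

```python
import torch

A, B, batch_1, batch_2 = 3, 2, 4, 5
u = torch.randn(A, B, batch_1, batch_2)
v = torch.randn(B, A, batch_1, batch_2)

# Reshape/permute + bmm approach, generalized to these sizes
_u = u.reshape(A, B, -1).permute(2, 0, 1)  # (batch_1 * batch_2, A, B)
_v = v.reshape(B, A, -1).permute(2, 0, 1)  # (batch_1 * batch_2, B, A)
result = torch.bmm(_u, _v).permute(1, 2, 0).reshape(A, A, batch_1, batch_2)

# Reference: one small matrix product per (i, j) position
expected = torch.empty(A, A, batch_1, batch_2)
for i in range(batch_1):
    for j in range(batch_2):
        expected[:, :, i, j] = u[:, :, i, j] @ v[:, :, i, j]

print(torch.allclose(result, expected, atol=1e-5))
```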
