jef
November 13, 2017, 10:26pm
I have batch data and want to take the dot product of W with the data, where W is a trainable parameter.
How do I compute a dot product between batch data and weights?
import torch
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)  # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze()  # error: bmm needs a 3-D second argument; want shape (10, 6)
result = result.view(10, 2, 3)  # desired final shape (10, 2, 3)
Update
This looks like it may work:
import torch
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)  # (10, 6, hid_dim)
W = torch.randn(hid_dim, 1)  # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)  # (10, hid_dim, 1): one view of W per batch item
result = torch.bmm(data, W).squeeze(-1)  # (10, 6)
result = result.view(10, 2, 3)
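For comparison, a minimal sketch (assuming a reasonably recent PyTorch with `torch.matmul` and `torch.einsum`) showing that the flatten/expand/bmm approach in the update gives the same numbers as a direct `torch.matmul` against a 1-D `W`, which contracts the last dimension and needs no reshape or expand. `W` is a plain tensor here, standing in for an `nn.Parameter`.

```python
import torch

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
W = torch.randn(hid_dim)  # stand-in for an nn.Parameter of shape (hid_dim,)

# bmm version from the update: flatten, expand W per batch item, squeeze
flat = data.view(10, 2 * 3, hid_dim)
W_batched = W.view(hid_dim, 1).unsqueeze(0).expand(10, hid_dim, 1)
res_bmm = torch.bmm(flat, W_batched).squeeze(-1).view(10, 2, 3)

# matmul with a 1-D second argument contracts the last dim of data
res_matmul = torch.matmul(data, W)  # shape (10, 2, 3) directly

# einsum spells out the same contraction over the hidden dimension
res_einsum = torch.einsum('bijh,h->bij', data, W)

print(res_matmul.shape)  # torch.Size([10, 2, 3])
print(torch.allclose(res_bmm, res_matmul, atol=1e-5))
print(torch.allclose(res_einsum, res_matmul, atol=1e-5))
```

The `matmul` form also keeps any leading batch dimensions as they are, so the `view(10, 2*3, hid_dim)` step is not needed.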
SimonW
(Simon Wang)
November 13, 2017, 10:39pm
First of all, everything should be wrapped in a Variable for W to be trainable. Secondly, W has shape (32,), which cannot be multiplied with tensors of shape (2, 3). So I assume that W is of shape (3, 2).
Then, you can use torch.bmm(data, W.unsqueeze(0).expand(10, 3, 2)). You probably don't need unsqueeze, but I don't have access to PyTorch right now, so you can check yourself.
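A quick check of that suggestion, assuming `data` has shape (10, 2, 3) and `W` the (3, 2) shape guessed in the reply. `Tensor.expand` can add new leading dimensions, so `unsqueeze(0)` is optional; `torch.matmul` also broadcasts a 2-D `W` over the batch without any expand.

```python
import torch

data = torch.randn(10, 2, 3)
W = torch.randn(3, 2)  # the (3, 2) shape assumed in the reply

# expand() prepends the new batch dimension itself, so unsqueeze(0) is optional
with_unsqueeze = torch.bmm(data, W.unsqueeze(0).expand(10, 3, 2))
without_unsqueeze = torch.bmm(data, W.expand(10, 3, 2))

# matmul broadcasts the 2-D W across the batch dimension of data
broadcast = torch.matmul(data, W)

print(with_unsqueeze.shape)  # torch.Size([10, 2, 2])
```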
jef
November 13, 2017, 10:54pm
@SimonW
Thank you. I updated my post. What about my code for the case where W is a single vector?
SimonW
(Simon Wang)
November 13, 2017, 10:56pm
It still doesn't make sense to bmm (10, 2, 3) and (10, 32, 1). What exactly should be multiplied with each matrix?
jef
November 13, 2017, 11:07pm
bmm of (10, 2*3, hid_dim) and (10, hid_dim, 1), not (10, 2, 3) and (10, 32, 1). Sorry for the confusion.
SimonW
(Simon Wang)
November 13, 2017, 11:12pm
Oh I see, sorry, I missed the hid_dim. Yeah, I think your code should work. Are you still seeing errors?
jef
November 14, 2017, 2:01am
No, thank you for your help!
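On the trainability point raised earlier: in PyTorch 0.4 and later, Variable was merged into Tensor, and a weight is usually made trainable by wrapping it in `nn.Parameter` inside a module. The sketch below packages the updated bmm code that way; the class name `HiddenDot` is hypothetical, and the backward call only checks that a gradient reaches `W`.

```python
import torch
import torch.nn as nn

class HiddenDot(nn.Module):
    """Hypothetical module: dot each hid_dim vector of the input with a learned W."""

    def __init__(self, hid_dim: int):
        super().__init__()
        # nn.Parameter registers W, so model.parameters() and optimizers see it
        self.W = nn.Parameter(torch.randn(hid_dim, 1))

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        n, a, b, h = data.shape
        flat = data.reshape(n, a * b, h)          # (N, 6, hid_dim)
        W = self.W.unsqueeze(0).expand(n, h, 1)   # (N, hid_dim, 1), a view, no copy
        out = torch.bmm(flat, W).squeeze(-1)      # (N, 6)
        return out.view(n, a, b)                  # (N, 2, 3)

model = HiddenDot(hid_dim=32)
data = torch.randn(10, 2, 3, 32)
result = model(data)
result.sum().backward()      # gradients flow back through expand into W
print(result.shape)          # torch.Size([10, 2, 3])
print(model.W.grad.shape)    # torch.Size([32, 1])
```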
I have tensors of size A: 32 x 4 x 1 and B: 4 x 2. I want my output tensor to be of shape 32 x 2 x 1.
Can you explain how I can multiply A and B?
This should work:
import torch
a = torch.randn(32, 4, 1)
b = torch.randn(4, 2)
c = torch.matmul(b.unsqueeze(0).permute(0, 2, 1), a)  # (1, 2, 4) @ (32, 4, 1) broadcasts to (32, 2, 1)
print(c.shape)
> torch.Size([32, 2, 1])
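Because `torch.matmul` broadcasts batch dimensions, the `unsqueeze(0)` in that answer appears to be optional: a plain 2-D `b.t()` of shape (2, 4) broadcasts against `a` the same way. A small comparison, also showing an `torch.einsum` form that contracts the shared size-4 dimension:

```python
import torch

a = torch.randn(32, 4, 1)
b = torch.randn(4, 2)

# Answer above: (1, 2, 4) @ (32, 4, 1) broadcasts to (32, 2, 1)
c = torch.matmul(b.unsqueeze(0).permute(0, 2, 1), a)

# A 2-D left operand is broadcast over the batch dimension of a as well
c_t = torch.matmul(b.t(), a)

# einsum: out[n, j, l] = sum_k b[k, j] * a[n, k, l]
c_e = torch.einsum('kj,nkl->njl', b, a)

print(c.shape, c_t.shape, c_e.shape)  # torch.Size([32, 2, 1]) for each
```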
Ravi_Raja
(Ravi Raja)
May 27, 2021, 6:54am
Thanks a lot man! It worked.