Hi, I am new to pytorch and meet a problem using LSTM. Maybe it is because of my input shape, but I don’t know how to do.
Here is my code (some lines omitted)
class LSTMModel(nn.Module):
def __init__(self, in_dim, n_layer):
super(LSTMModel, self).__init__()
self.n_layer = n_layer
self.hidden_dim = in_dim
self.lstm = nn.LSTM(in_dim, self.hidden_dim, n_layer, batch_first=True)
def forward(self, x):
out, h = self.lstm(x)
return h[0]
class MyModel(nn.Module):
def __init__(self, config):
super(TATransEModel, self).__init__()
...
self.lstm = LSTMModel(self.embedding_size, 1)
..
def unroll(self, data, unroll_len = 4):
result = []
for i in range(len(data) - unroll_len):
result.append(data[i: i+unroll_len])
return result
def forward(self, pos_h, pos_t, pos_r, pos_tem, neg_h, neg_t, neg_r, neg_tem):
pos_h_e = self.ent_embeddings(pos_h)
pos_t_e = self.ent_embeddings(pos_t)
pos_r_e = self.rel_embeddings(pos_r)
pos_tem_e = []
for tem in pos_tem:
tem_e = []
for token in tem:
token_e = self.tem_embeddings(token)
tem_e.append(token_e)
pos_tem_e.append(tem_e)
pos_rseq_e = []
# add LSTM
pos_input_r = self.unroll(pos_r_e)
pos_hidden_r = self.lstm(pos_input_r)
The error message is:
Traceback (most recent call last):
File "train_new.py", line 237, in <module>
pos, neg = model(pos_h_batch, pos_t_batch, pos_r_batch, pos_time_batch, neg_h_batch, neg_t_batch, neg_r_batch, neg_time_batch)
File "/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/Users/genius/dynamic-KG-basic/model.py", line 92, in forward
pos_hidden_r = self.lstm(pos_input_r)
File "/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/Users/genius/dynamic-KG-basic/model.py", line 30, in forward
out, h = self.lstm(x)
File "/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/anaconda3/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 165, in forward
max_batch_size = input.size(0) if self.batch_first else input.size(1)
AttributeError: 'list' object has no attribute 'size'
Thank you for having a look!