Hello everybody,
Recently, I am trying to implement a two-layer LSTM model and train it with my dataset.
This is my model implementation:
class LstmModel(nn.Module):
def __init__(self, device):
super(LstmModel, self).__init__()
self.lstm1 = nn.LSTM(
input_size=2048, hidden_size=1024, batch_first=True)
self.lstm2 = nn.LSTM(
input_size=1024, hidden_size=128, batch_first=True)
self.linear = nn.Linear(128, 1)
self.sigmoid = nn.Sigmoid()
self.device = device
def forward(self, x):
h0 = torch.zeros((1, 32, 1024)).to(self.device)
c0 = torch.zeros((1, 32, 1024)).to(self.device)
x, _ = self.lstm1(x, (h0, c0))
h1 = torch.zeros((1, 32, 128)).to(self.device)
c1 = torch.zeros((1, 32, 128)).to(self.device)
x, _ = self.lstm2(x, (h1, c1))
x = self.linear(x)
x = self.sigmoid(x)
return x
And this is my dataset implementation.
class FeaturesDataset(Dataset):
def __init__(self):
csvFile = pd.read_csv(
"E:\\Datasets\\VQA\\KoNVID_1k_LSTM_CNN\\Features_240\\raw_data.csv")
self.X = csvFile['flickr_id']
self.y = csvFile['mos']
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
dist_path = 'E:\Datasets\VQA\KoNVID_1k_LSTM_CNN\Features_240'
video_name = str(int(self.X[idx])) + '.txt'
x = np.loadtxt(join(dist_path, video_name))
y = self.y[idx]
return torch.tensor(x), y
The following code is one row of my dataset.
The first element is my features, with the shape of 240 * 2048, which 240 is the length of my sequence and 2048 is the number of input features.
(tensor([[0.1705, 0.0173, 1.1980, ..., 0.0022, 1.1543, 0.0000],
[0.1560, 0.0206, 1.2292, ..., 0.0000, 1.1819, 0.0000],
[0.0830, 0.0018, 1.3242, ..., 0.0026, 1.3389, 0.0024],
...,
[0.0645, 0.0062, 1.3789, ..., 0.0031, 1.0773, 0.0000],
[0.0683, 0.0048, 1.3562, ..., 0.0025, 1.1218, 0.0000],
[0.0811, 0.0058, 1.4024, ..., 0.0017, 1.1745, 0.0000]],
dtype=torch.float64), 4.64)
(venv) PS C:\Users\HP\my-python-pro
The problem is while training the model, I get RuntimeError: Input and hidden tensors are not the same dtype, found input tensor with Double and hidden tensor with Float
.
Also, the batch size of my data is 32. Does anybody know the reason?