Thank you, I forgot that haha.
I have another issue now; this is the error I get:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
This is my complete code:
import torch
import torch.nn as nn

# GoogLeNet with the last layer modified (an LSTM follows)
class gN_changed(nn.Module):
    def __init__(self, latent_dim=512):
        super(gN_changed, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', pretrained=True)
        # freeze the parameters (trains faster and keeps the ImageNet weights)
        for params in self.model.parameters():
            params.requires_grad = False
        # replace the last fully connected layer
        self.model.fc = nn.Linear(self.model.fc.in_features, latent_dim)

    def forward(self, x):
        return self.model(x)


class Lstm(nn.Module):
    def __init__(self, latent_dim=512, hidden_size=256, lstm_layers=2, bidirectional=True):
        super(Lstm, self).__init__()
        self.Lstm = nn.LSTM(latent_dim, hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True, bidirectional=bidirectional)
        self.hidden_state = None

    def reset_hidden_state(self):
        self.hidden_state = None

    def forward(self, x):
        output, self.hidden_state = self.Lstm(x, self.hidden_state)
        return output


class ConvLstm(nn.Module):
    def __init__(self, latent_dim=512, hidden_size=256, lstm_layers=2, bidirectional=True, n_class=10):
        super(ConvLstm, self).__init__()
        self.model = gN_changed(latent_dim)
        self.Lstm = Lstm(latent_dim, hidden_size, lstm_layers, bidirectional)
        self.output_layer = nn.Sequential(
            nn.Linear(2 * hidden_size if bidirectional else hidden_size, n_class),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        batch_size, timesteps, channel_x, h_x, w_x = x.shape
        # run the CNN on every frame, then feed the per-frame features to the LSTM
        conv_input = x.view(batch_size * timesteps, channel_x, h_x, w_x)
        conv_output = self.model(conv_input)
        lstm_input = conv_output.view(batch_size, timesteps, -1)
        lstm_output = self.Lstm(lstm_input)
        lstm_output = lstm_output[:, -1, :]
        output = self.output_layer(lstm_output)
        return output
When I create the model, I make sure that both the model and the data are on the same device (cuda), but I am not sure what goes wrong.
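For completeness, this is roughly how I create the model and move everything to the GPU (simplified sketch of my setup; train_loader just stands in for my actual DataLoader):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ConvLstm(latent_dim=512, hidden_size=256, lstm_layers=2, bidirectional=True, n_class=10)
model = model.to(device)  # should move GoogLeNet, the LSTM and the classifier to the GPU

for frames, labels in train_loader:  # frames: (batch, timesteps, C, H, W)
    frames = frames.to(device)  # move the batch to the same device as the model
    labels = labels.to(device)
    model.Lstm.reset_hidden_state()  # clear the hidden state between clips
    preds = model(frames)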