Hi everyone,
Before I added the nn.DataParallel lines to the code below, I was running into an "Error(s) in loading state_dict for EncoderCNN:" error. I read the post about the solution to loading a state_dict and tried to follow it, but I am not sure I added DataParallel the right way. I couldn't find a clear article on how to do this, so could you check my code and suggest a solution, please?
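For context, this is my current understanding of the usual DataParallel pattern (wrapping the constructed model instances, rather than assigning the nn.DataParallel class inside __init__ as I did below). It uses the EncoderCNN/DecoderRNN classes defined further down, and please correct me if this is wrong:

# My understanding of the usual pattern (please correct me if wrong):
# wrap the model *instances* after construction, not inside __init__.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
if torch.cuda.device_count() > 1:
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()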
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # Pretrained ResNet-50 with the final fc layer removed.
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        self.saveWeights = nn.DataParallel  # <-- the line I added; is this the right way?

    def forward(self, images):
        features = self.resnet(images)
        features = Variable(features.data)  # detach the CNN features from the graph
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.saveWeights = nn.DataParallel  # <-- same line added here

    def forward(self, features, captions):
        # Drop the last (<end>) token so the sequence length still matches
        # the targets after the image feature is prepended.
        captions = captions[:, :-1]
        embeddings = self.embed(captions)
        # Prepend the image features as the first "word" of the sequence.
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, inputs, states=None, max_len=20):
        """Accepts a pre-processed image tensor (inputs) and returns a predicted
        sentence (list of tensor ids of length max_len)."""
        sampled_ids = []
        for i in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.max(1)[1]  # greedy: pick the most likely word id
            sampled_ids.append(predicted.data[0])
            # Feed the predicted word back in as the next input.
            inputs = self.embed(predicted)
            inputs = inputs.unsqueeze(1)
        return sampled_ids
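In case it matters, this is roughly how I call sample() at inference time (test_image here is a placeholder name for a pre-processed image tensor from my notebook):

# Rough inference usage (test_image is a placeholder).
encoder.eval()
decoder.eval()
features = encoder(test_image)    # shape: (1, embed_size)
inputs = features.unsqueeze(1)    # shape: (1, 1, embed_size), as the LSTM expects
# Note: if decoder is wrapped in nn.DataParallel, the method is decoder.module.sample(...)
sampled_ids = decoder.sample(inputs)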
import torch.utils.data as data
import numpy as np
import os
import sys
import requests
import time

# Open the training log file.
f = open(log_file, 'w')
for epoch in range(1, num_epochs + 1):

    for i_step in range(1, int(total_step) + 1):

        # Randomly sample a caption length, and sample indices with that length.
        indices = data_loader.dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler

        # Obtain the batch.
        for batch in data_loader:
            images, captions = batch[0], batch[1]
            break

        # Convert the batch of images and captions to PyTorch Variables.
        images = to_var(images, volatile=True)
        captions = to_var(captions)

        # Zero the gradients.
        decoder.zero_grad()
        encoder.zero_grad()

        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        outputs = decoder(features, captions)

        # Calculate the batch loss.
        loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

        # Backward pass.
        loss.backward()

        # Update the parameters in the optimizer.
        optimizer.step()

        # Get training statistics.
        stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (
            epoch, num_epochs, i_step, total_step, loss.data[0], np.exp(loss.data[0]))

        # Print training statistics (on the same line).
        print('\r' + stats, end="")
        sys.stdout.flush()

        # Print training statistics to file.
        f.write(stats + '\n')
        f.flush()

        # Print training statistics (on a new line) every print_every steps.
        if i_step % print_every == 0:
            print('\r' + stats)

    # Save the weights.
    if epoch % save_every == 0:
        torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-%d.pkl' % epoch))
        torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-%d.pkl' % epoch))

# Close the training log file.
f.close()
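From the posts I've read, I suspect the "Error(s) in loading state_dict for EncoderCNN" happens because nn.DataParallel prefixes every parameter key with "module.", so a checkpoint saved from a wrapped model no longer matches a plain EncoderCNN. Is one of these two workarounds (a sketch, assuming the models were actually wrapped as shown at the top) the right fix?

# Workaround 1: save the underlying model's weights so the keys carry
# no "module." prefix (only valid when encoder is a DataParallel wrapper).
torch.save(encoder.module.state_dict(),
           os.path.join('./models', 'encoder-%d.pkl' % epoch))

# Workaround 2: strip the "module." prefix from an already-saved checkpoint
# when loading it into a plain (unwrapped) EncoderCNN.
state_dict = torch.load(os.path.join('./models', 'encoder-1.pkl'))
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
              for k, v in state_dict.items()}
encoder = EncoderCNN(embed_size)
encoder.load_state_dict(state_dict)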