I am making a simple recurrent neural network architecture for CIFAR10 image classification. I am getting batch size mismatch error.
Data Preprocess
all_transforms = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010])]
)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=all_transforms, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=all_transforms, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=512, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=512, shuffle=True)
Simple RNN
class SimpleRNN(nn.Module):
def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
super(SimpleRNN, self).__init__()
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
self.fc = nn.Linear(hidden_dim, num_classes) # hidden dimension is output dimension
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)
out, _ = self.rnn(x, h0)
out = out[:, -1, :]
out = self.fc(out)
return out
x.shape--> torch.Size([1536, 32, 32])
h0.shape--> torch.Size([2, 1536, 128])
out.shape of rnn -> torch.Size([1536, 32, 128])
out = out[:, -1, :] --> torch.Size([1536, 128])
Training Loop
for epoch in range(epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).cuda()
labels = labels.cuda()
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epochs [{}/{}], Loss: {:4f}".format(epoch + 1, epochs, loss.item()))
input_size = 32
hidden_size = 128
num_layers = 2
num_classes = 10
sequence_length = 32
Traceback
Traceback (most recent call last):
File "/media/cvpr/CM_1/tutorials/pytorch_simple_rrn.py", line 62, in <module>
loss = criterion(outputs, labels)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 961, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/functional.py", line 2468, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/functional.py", line 2261, in nll_loss
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
ValueError: Expected input batch_size (1536) to match target batch_size (512).