Hello, I am creating an RNN for binary classification. The goal is to look at binary arrays of length 60, in which arrays containing 2 or more consecutive 1s are not part of the grammar (target = 0) and those that do not are part of the grammar (target = 1). The test data is similar to the training data except that it is of length 80. In my model I attempted to set the batch sizes to 3 and 4 for the training and test sets respectively; however, I get the error "Expected input batch_size (3) to match target batch_size (1)." I am not sure how to get this network to work. The goal is to use the binary arrays to predict either class 1 or 0, so the target batch size should only be 1, shouldn't it?
Here is the model class:
# Hyper-parameters shared by the model, the evaluation helper and the training loop.
input_size = 1        # features per time step
batch = 3             # number of rows each loader batch is reshaped into
sequence_length = 20  # time steps per row; also fixes the width of RNN.fc
hidden_size = 128
num_classes = 2
num_layers = 1


class RNN(nn.Module):
    """Vanilla RNN classifier that emits per-class log-probabilities.

    The recurrent outputs of every time step are flattened and fed to a single
    linear layer, so inputs must contain exactly ``sequence_length`` (module
    level constant) time steps.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the recurrent hidden state.
        num_layers: number of stacked recurrent layers.
        output_size: number of classes.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # The classifier sees all time steps at once, hence the fixed width.
        self.fc = nn.Linear(self.hidden_size * sequence_length, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(x.size(0), output_size)``.

        Args:
            x: tensor of shape ``(rows, sequence_length, input_size)``
               (``batch_first=True`` layout).
        """
        # Size the initial hidden state from the input itself rather than from
        # the global ``batch`` constant, and create it on the input's device
        # and dtype, so any number of rows works on either CPU or GPU.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # Forward propagation.
        out, _ = self.rnn(x, h0)
        out = out.reshape(out.shape[0], -1)  # (rows, sequence_length * hidden_size)
        out = self.fc(out)
        return self.softmax(out)
The test definition:
# History of the normalised losses; test() appends one entry per call.
tot_losses = list()
# Cumulative number of training samples at each epoch boundary (0 .. num_epochs).
tot_counter = [
    epoch_idx * len(train_loader.dataset)
    for epoch_idx in range(num_epochs + 1)
]
def test(model, loader, batch = 3):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Each loader batch is moved to the GPU and reshaped to
    ``(batch, sequence_length, input_size)`` before being passed to the model.

    NOTE(review): because of that reshape the model returns ``batch`` rows of
    predictions, so ``target`` must also contain ``batch`` labels; otherwise the
    criterion raises "Expected input batch_size (..) to match target
    batch_size (..)". Confirm how the loader's samples map to labels.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: data loader yielding ``(data, target)`` pairs.
        batch: number of rows every loader batch is reshaped into.

    Returns:
        A tuple ``(loss, accuracy, confusion_matrix)`` where ``loss`` is the
        summed criterion value divided by ``len(loader.dataset)``, ``accuracy``
        is the percentage of correct predictions (Python float) and
        ``confusion_matrix`` is built by ``ConfusionMatrix``.

    Side effects:
        Appends the returned loss to the module-level ``tot_losses`` list.
    """
    model.eval()
    tot_loss, correct = 0.0, 0.0
    predictions = []
    targets = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.cuda(), target.cuda()
            # Honour the caller's ``batch`` instead of hard-coding 3 or 4.
            data = torch.reshape(data, (batch, sequence_length, input_size))
            output = model(data)
            tot_loss += criterion(output, target).item()
            # Index of the highest log-probability is the predicted class.
            pred = output.max(1, keepdim = True)[1]
            targets.append(target.cpu())
            predictions.append(pred.cpu())
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Normalise by the loader that was actually evaluated, not by the global
    # test_loader, so validation and test runs are both scaled correctly.
    num_samples = len(loader.dataset)
    tot_loss /= num_samples
    tot_losses.append(tot_loss)
    confusion_matrix = ConfusionMatrix(predictions, targets, len(loader))
    return tot_loss, 100. * correct / num_samples, confusion_matrix
Training loop:
# NOTE(review): the training loop appends to train_losses and train_counter, but
# their initialisation is commented out here; presumably they are created in an
# earlier cell. Confirm, otherwise the first logged batch raises NameError.
# train_losses = []
# train_counter = []

logdir = generate_unique_logpath(top_logdir, "RNN_Adam_IL20_LR00005")
print("Logging to {}".format(logdir))
# -> Prints e.g. "Logging to <top_logdir>/RNN_Adam_IL20_LR00005_14"
# makedirs(exist_ok=True) also creates missing parents and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(logdir, exist_ok=True)

# Baseline: score the untrained model on the validation set.
print("Before Training Validation Set Performance:")
val_loss, val_acc, confusion_M = test(model, val_loader)
print("\nValidation : Avg. Loss : {:.4f}, Accuracy : {:.2f}\n".format(val_loss, val_acc))
print("Validation Set Confusion Matrix: \n" + str(confusion_M))
print()

# Receives the validation loss after every epoch via update().
model_checkpoint = ModelCheckpoint(logdir + "/best_model.pt", model)
# One pass over train_loader per epoch, followed by a validation run whose loss
# is handed to the checkpoint helper.
for epoch in range(num_epochs):
    print("---------------------------------------------------------------------------")
    # test() leaves the model in eval mode, so re-enable training mode once at
    # the start of every epoch.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.cuda()
        targets = target.cuda()
        # Split each loader batch into `batch` rows of `sequence_length` steps.
        data = torch.reshape(data, (batch, sequence_length, input_size))

        output = model(data)
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log progress and record the loss every 50 batches.
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append((batch_idx * batch_size_train) + ((epoch) * len(train_loader.dataset)))

    # Evaluate val_loader with the same (default) chunk count as the
    # pre-training validation call; passing batch = 4 here asked for a
    # different reshape of the very same data.
    val_loss, val_acc, confusion_M = test(model, val_loader)
    model_checkpoint.update(val_loss)
    print("\n Validation : Avg. Loss : {:.4f}, Accuracy : {:.2f}\n".format(val_loss, val_acc))
    print("Validation Set Confusion Matrix: \n" + str(confusion_M))
    print()
    print("---------------------------------------------------------------------------")
Error:
Logging to C:\Users\Daniel\OneDrive\Documents\Neural Networks Hw\Best Models Project\logs\RNN_Adam_IL20_LR00005_14
Before Training Validation Set Performance:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-23-359fc80a1bdb> in <module>
9
10 print("Before Training Validation Set Performance:")
---> 11 val_loss, val_acc, confusion_M = test(model, val_loader)
12 print("\nValidation : Avg. Loss : {:.4f}, Accuracy : {:.2f}\n".format(val_loss, val_acc))
13 print("Validation Set Confusion Matrix: \n" + str(confusion_M))
<ipython-input-21-ff2c743fda6d> in test(model, loader, batch)
19 #print(data.size())
20 output = model(data)
---> 21 tot_loss += criterion(output, target).cpu().numpy()
22 pred = output.data.max(1, keepdim = True)[1]
23 targets.append(target.cpu())
~\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~\anaconda3\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
202
203 def forward(self, input, target):
--> 204 return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
205
206
~\anaconda3\lib\site-packages\torch\nn\functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1834 if input.size(0) != target.size(0):
1835 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1836 .format(input.size(0), target.size(0)))
1837 if dim == 2:
1838 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (3) to match target batch_size (1).
Thank you!