I’m trying to build a simple MNIST Model and this is what I’ve built -
training_loader = DataLoader(training_dataset, 128, shuffle = True)
validation_loader = DataLoader(validation_dataset, 128)
class mnistmodel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(784, 10)
self.linear2 = nn.Linear(10, 5)
self.linear3 = nn.Linear(5, 10)
def forward(self, xb):
xb.reshape(-1, 784)
predicted = F.relu(self.linear1(xb))
predicted.reshape(-1, 10)
predicted = F.relu(self.linear2(predicted))
predicted.reshape(-1, 5)
predicted = self.linear3(predicted)
return predicted
def training_step(self, batch):
images, labels = batch
predicted = self(images)
loss = F.cross_entropy(predicted, labels)
return loss
def validation_step(self, batch):
images, labels = batch
predicted = self(images)
loss = F.cross_entropy(predicted, labels)
_, preds = torch.max(predicted, dim=1)
accuracy = torch.tensor(torch.sum(preds == labels).item() / len(preds))
return {'validation_loss': loss, 'validation_accuracy': accuracy}
def validation_epoch_end(self, outputs):
batch_losses = [x['validation_loss'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean()
batch_accs = [x['validation_acc'] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean()
return {'validation_loss': epoch_loss.item(), 'validation_accuracy': epoch_acc.item()}
def epoch_end(self, epoch, result):
print(f"Epoch [{epoch}], val_loss: {result['validation_loss']}, val_acc: {result['validation_acc']}")
model = mnistmodel()
def fit_mnist(epochs, lr, model, training_loader, validation_loader, optimizer_function=torch.optim.SGD):
optimizer = optimizer_function(model.parameters(), lr)
history = []
for epoch in range(epochs):
for batch in training_loader:
loss = model.training_step(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
result = evaluate(model, validation_loader)
model.epoch_end(epoch, result)
history.append(result)
return history
history1 = fit_mnist(5, 0.001, model, training_loader, validation_loader)
I get the following error -
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-162-48e4fe0cc2d9> in <module>()
----> 1 history1 = fit_mnist(5, 0.001, model, training_loader, validation_loader)
6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
1751 if has_torch_function_variadic(input, weight):
1752 return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753 return torch._C._nn.linear(input, weight, bias)
1754
1755
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x28 and 784x10)
I’m new to pytorch but as far as I understand the shapes seem to be fine, what is going wrong here?