Shape Error PyTorch

I’m trying to build a simple MNIST Model and this is what I’ve built -

training_loader = DataLoader(training_dataset, 128, shuffle = True)
validation_loader = DataLoader(validation_dataset, 128)

class mnistmodel(nn.Module):
  def __init__(self):
    self.linear1 = nn.Linear(784, 10)
    self.linear2 = nn.Linear(10, 5)
    self.linear3 = nn.Linear(5, 10)  

  def forward(self, xb):
    xb.reshape(-1, 784)
    predicted = F.relu(self.linear1(xb))
    predicted.reshape(-1, 10)
    predicted = F.relu(self.linear2(predicted))
    predicted.reshape(-1, 5)
    predicted = self.linear3(predicted)
    return predicted
  def training_step(self, batch):
    images, labels = batch
    predicted = self(images)
    loss = F.cross_entropy(predicted, labels)
    return loss
  def validation_step(self, batch):
    images, labels = batch
    predicted = self(images)
    loss = F.cross_entropy(predicted, labels)
    _, preds = torch.max(predicted, dim=1)
    accuracy = torch.tensor(torch.sum(preds == labels).item() / len(preds))
    return {'validation_loss': loss, 'validation_accuracy': accuracy}

  def validation_epoch_end(self, outputs):
    batch_losses = [x['validation_loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()
    batch_accs = [x['validation_acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()
    return {'validation_loss': epoch_loss.item(), 'validation_accuracy': epoch_acc.item()}
def epoch_end(self, epoch, result):
    print(f"Epoch [{epoch}], val_loss: {result['validation_loss']}, val_acc: {result['validation_acc']}")
model = mnistmodel()

def fit_mnist(epochs, lr, model, training_loader, validation_loader, optimizer_function=torch.optim.SGD):
    optimizer = optimizer_function(model.parameters(), lr)
    history = []
    for epoch in range(epochs):
        for batch in training_loader:
            loss = model.training_step(batch)
        result = evaluate(model, validation_loader)
        model.epoch_end(epoch, result)

    return history

history1 = fit_mnist(5, 0.001, model, training_loader, validation_loader)

I get the following error -

RuntimeError                              Traceback (most recent call last)
<ipython-input-162-48e4fe0cc2d9> in <module>()
----> 1 history1 = fit_mnist(5, 0.001, model, training_loader, validation_loader)

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/ in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x28 and 784x10)

I’m new to pytorch but as far as I understand the shapes seem to be fine, what is going wrong here?

reshape is not an inplace operation, so you need to assign the return value to another object:

xb = xb.reshape(-1, 784)

Generally, I would recommend to use this approach:

xb = xb.view(xb.size(0), -1)

to keep the batch dimension and to get a better error message in case the feature dimension is incorrect.

Also, you wouldn’t need to reshape the activations, as the linear layers should already return the expected output.

Thanks a lot this clears it up