Switching between torch.flatten(x) and x.view(-1, 256)

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)   # 1 x 28 x 28 -> 6 x 24 x 24
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 6 x 12 x 12 -> 16 x 8 x 8
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
def training_loop():
    for epoch in range(epochs):
        model.train()
        for img, label in train_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)
            loss = loss_fn(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(loss.item())  # loss of the last batch in this epoch
        print(f'epoch {epoch} ended')
    print('training done!')

    torch.save(model.state_dict(), './MNIST.pth')

I have this code here.
I used torch.flatten(x) at first. With that, the fc1 layer expects 2560 input features and nn.CrossEntropyLoss expects float target values. After training, the loss is around 100.
I then changed torch.flatten(x) to x.view(-1, 256). Now fc1 only expects 256 input features and nn.CrossEntropyLoss wants Long targets, and the loss value looks good (around 0.01).
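
To make the shapes concrete, here is a minimal check (the batch size of 10 is an assumption that matches the 2560 figure):

import torch

x = torch.randn(10, 16, 4, 4)       # shape after conv2 + pool for a batch of 10
print(torch.flatten(x).shape)       # torch.Size([2560]) -> batch dimension is gone
print(x.view(-1, 256).shape)        # torch.Size([10, 256]) -> batch dimension kept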

Why does this happen?

The usage of torch.flatten(x) sounds wrong, as it would flatten the entire tensor, including the batch dimension, as seen here:

x = torch.randn(10, 256)
print(torch.flatten(x).shape)
# torch.Size([2560])

Using x = x.view(-1, 256) sounds better, as long as you can guarantee the tensor has a feature size of 256. The better approach would be to use x = x.view(x.size(0), -1) to make sure the batch size does not change.
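
As a small follow-up sketch, x.view(x.size(0), -1), torch.flatten(x, start_dim=1), and nn.Flatten() (which starts flattening at dim 1 by default) all keep the batch dimension intact:

import torch
import torch.nn as nn

x = torch.randn(10, 16, 4, 4)                 # activation after the second pooling stage
print(x.view(x.size(0), -1).shape)            # torch.Size([10, 256])
print(torch.flatten(x, start_dim=1).shape)    # torch.Size([10, 256])
print(nn.Flatten()(x).shape)                  # torch.Size([10, 256])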


Thanks for the help!
This solved my problem, but why does switching between flatten() and view(-1, 256) affect the loss values?

I guess it's because your loss calculation was wrong in the flatten() use case, where e.g. unexpected broadcasting might have been used.
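
Not a definitive reproduction, but a minimal sketch of one plausible failure mode (assuming a batch size of 10, nn.CrossEntropyLoss, and a PyTorch version that accepts probability targets): with the flattened 1D output, the criterion treats the 10 logits as a single sample and the float labels as class "probabilities", which would explain both the float requirement and the loss of roughly 100.

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
labels = torch.randint(0, 10, (10,))        # class indices for a batch of 10

# Correct case: [batch_size, num_classes] logits + Long class indices
out_2d = torch.randn(10, 10)
print(loss_fn(out_2d, labels))              # around 2.3 for random logits

# Flatten case: the whole batch collapses to a single vector of 10 logits,
# so the criterion only accepts a float target of the same shape and
# interprets the class indices (0-9) as "probabilities", inflating the loss.
out_1d = torch.randn(10)
print(loss_fn(out_1d, labels.float()))      # large value, on the order of 100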
