Model.eval() harms validation

Hi, everyone! I’m stuck on a silly problem. Check out my code for a toy example:

def train(model, optim, criterion, train_dataloader, val_dataloader, device, epochs):
    """Train `model` for `epochs` epochs, tracking per-epoch train/val accuracy.

    Args:
        model: the network to optimize (moved to `device` by the caller).
        optim: optimizer over `model.parameters()`.
        criterion: loss function, e.g. CrossEntropyLoss.
        train_dataloader / val_dataloader: iterables of (x, y) batches.
        device: torch device string/object the batches are moved to.
        epochs: number of epochs to run.

    Returns:
        (train_acc_, val_acc_): lists of per-epoch accuracies as Python floats.
    """
    train_acc_ = []
    val_acc_ = []
    for ep in range(1, epochs + 1):
        model.train()  # enable dropout / BatchNorm batch statistics
        correct = 0
        total = 0
        for x, y in tqdm(train_dataloader):
            optim.zero_grad()
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optim.step()
            # Count correct samples instead of averaging per-batch means:
            # averaging batch means over-weights a smaller final batch.
            # .item() also drops the GPU tensor so nothing accumulates on device.
            correct += (y == torch.argmax(out, 1)).sum().item()
            total += y.size(0)
        train_acc = correct / total

        model.eval()  # use BatchNorm running stats / disable dropout
        correct = 0
        total = 0
        with torch.no_grad():  # no autograd graph needed for validation
            for x, y in tqdm(val_dataloader):
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                correct += (y == torch.argmax(out, 1)).sum().item()
                total += y.size(0)
        val_acc = correct / total

        train_acc_.append(train_acc)
        val_acc_.append(val_acc)
        clear_output()
        # x-axis is the epoch number (1..ep) so the curve aligns with `ep`
        plt.plot(np.arange(1, ep + 1), train_acc_, label=f"Train acc (last: {round(train_acc, 3)})")
        plt.plot(np.arange(1, ep + 1), val_acc_, label=f"Val acc (last: {round(val_acc, 3)})")
        plt.legend()
        plt.show()
    return train_acc_, val_acc_

# Normalize with dataset statistics computed beforehand (train_mean/train_std
# are defined elsewhere in this post).
tr = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(train_mean, train_std)
])

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = torchvision.models.resnet18()

# Adapt ResNet-18 (designed for 224x224 ImageNet) to 32x32 CIFAR-10 inputs:
# smaller first conv, and a 10-class head.
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
model.fc = torch.nn.Linear(512, 10)

model = model.to(device)

train_data = datasets.CIFAR10("./cifar10_data", train=True, transform=tr)
val_data = datasets.CIFAR10("./cifar10_data", train=False, transform=tr)

# shuffle=True is essential for training: without it every epoch replays the
# same fixed batch order, which skews BatchNorm running statistics and can
# make eval-mode validation accuracy collapse even while train accuracy rises.
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=256)

criterion = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())

When I run this code, my accuracy plot looks like this:
image

But when I remove model.train(False) from my train function, everything is fine.

Where did I go wrong?

try this model.eval() instead of model.train(False)

Pardon me if couldn’t understand your question well.

Hi!
AFAIK, it’s the same thing. I’ve tried that as well, but it still doesn’t work.

1 Like

I’m not sure why you are getting a random validation accuracy, but your code seems to work at least for 3 epochs (I stopped it afterwards):
image