Model works for inference without .eval() and doesn't with it

Hi,
I’ve trained my model and it works great (it’s trying to count points in images). When I run inference with it and the trained weights, I get a sensible value, but it varies by a couple of percent when run on the same data.

I realised I’d not set the model to eval mode using model.eval() before doing the inference, and thought that leaving dropout active was the source of the variation. However, when I call model.eval() the variation disappears but the answer is now completely wrong. This stability suggests that it’s not numerical instability causing the initial variation.

I’m also worried that the dropout probabilities I’m using might be to blame.

My model is this;

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,75, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(75,50, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(50,25, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(75)
        self.bn2 = nn.BatchNorm2d(50)
        self.bn3 = nn.BatchNorm2d(25)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(25*66*50, 512)
        self.fc2 = nn.Linear(512,512)
        self.fc3 = nn.Linear(512,1)
        self.float()

    def forward(self, x):
        out = F.max_pool2d(self.bn1(torch.relu(self.conv1(x))),2)
        out = F.max_pool2d(self.bn2(self.dropout1(torch.relu(self.conv2(out)))),2)
        out = F.max_pool2d(self.bn3(torch.relu(self.conv3(out))),2)
        out = out.view(-1,25*66*50)
        out = torch.relu(self.fc1(out))
        out = self.dropout2(out)
        out = torch.relu(self.fc2(out))
        out = self.fc3(out)
        return out

Here’s the inference code;

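if __name__=="__main__":
    model=Net()
    model.eval()
    device_in_use = get_device()
    weights_file = 'weights.pt'
    model.load_state_dict(torch.load(weights_file,map_location=device_in_use))
    # Move the model to the same device as the input
    model.to(device=device_in_use)
    # Tensor.to() returns a new tensor, so the result has to be assigned
    image_to_infer = get_data(sys.argv[1]).to(device=device_in_use)

    output = model(image_to_infer).item()

    # max_ions is the same scaling constant used to normalise labels in training
    print(output*max_ions)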

The training loop looks like this;

def runModel(model, n_epochs, max_ions, optimizer,loss_fn, train_loader, validation_loader, save_weights_as, save_graph_as, writerSuffix="" ):
    device_in_use = get_device()
    max_ions = int(max_ions)
    writer = SummaryWriter()
    model.to(device=device_in_use)
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs_gpu= imgs.to(device=device_in_use)
            outputs = model(imgs_gpu)
            normalised_labels = normalise_label2(labels,max_ions).to(device=device_in_use)
            loss = loss_fn(outputs,normalised_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            current_loss = loss.item()
            loss_train += current_loss
        writer.add_scalar("loss/training", loss_train/len(train_loader))
        nll =[]
        ol=[]

        validation_loss = 0
        for imgs, labels in validation_loader:
            imgs= imgs.to(device=device_in_use)
            outputs = model(imgs)
            normalised_labels = normalise_label2(labels,max_ions).to(device=device_in_use)
            #print(f'{outputs.shape} {normalised_labels.shape}')
            output_list = outputs.squeeze().tolist()
            normalised_labels_list = normalised_labels.squeeze().tolist()
            nll.extend(normalised_labels_list)
            ol.extend(output_list)
            loss = loss_fn(outputs,normalised_labels)
            validation_loss += loss.item()

        print(f"Epoch:{epoch} Loss:{loss_train/len(train_loader)} Validation:{validation_loss/len(validation_loader)} len train_loader:{len(train_loader)} len validation_loader: {len(validation_loader)}")
        writer.add_scalar("loss/validation", validation_loss/len(validation_loader))
        nll = [i * max_ions for i in nll]
        ol = [i * max_ions for i in ol]
        save_weights(model, save_weights_as, epoch)
    draw_box_plot(nll, ol, save_graph_as)
    writer.close()

Reading through this, I’m wondering if missing out model.train() in the training cycle is the cause of this? Also, do I have to change this back to model.eval() for the validation step?

Love using PyTorch, this is the first project I’ve done with it! Thanks for any advice you can offer.

Yes, calling model.eval() during validation is the common approach, as it disables dropout layers and makes batchnorm layers use their running stats, which avoids outputs that depend on the batch size. However, the latter seems to cause issues in your model: the running stats of the batchnorm layers might not have converged properly, or might not represent the stats of the forward activations in general. This issue has been discussed a few times here already, and you could e.g. play around with the momentum to see if it helps.
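
As a rough illustration of the pattern described above, here is a minimal sketch that switches between model.train() and model.eval() inside each epoch and exposes the batchnorm momentum. The tiny network, the random data, and the names train_batches and val_batches are placeholders invented for this example, not the model or loaders from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network, small enough to run on CPU.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(4, momentum=0.1),  # 0.1 is the PyTorch default; this is the knob to experiment with
    nn.Dropout(0.2),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

# Random placeholder batches standing in for train_loader / validation_loader.
train_batches = [(torch.randn(16, 1, 8, 8), torch.rand(16, 1)) for _ in range(5)]
val_batches = [(torch.randn(16, 1, 8, 8), torch.rand(16, 1)) for _ in range(2)]

for epoch in range(2):
    # Training mode: dropout is active and batchnorm uses batch statistics
    # while updating its running mean and variance.
    model.train()
    for imgs, labels in train_batches:
        optimizer.zero_grad()
        loss = loss_fn(model(imgs), labels)
        loss.backward()
        optimizer.step()

    # Eval mode: dropout is disabled and batchnorm uses the running statistics.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for imgs, labels in val_batches:
            val_loss += loss_fn(model(imgs), labels).item()
    print(f"epoch {epoch}: validation loss {val_loss / len(val_batches):.4f}")

# The running statistics can be inspected directly to check whether they look plausible.
bn = model[2]
print(bn.running_mean, bn.running_var, bn.num_batches_tracked)
```

If the printed running_mean and running_var look very different from the statistics of real activations, that would be consistent with the eval-mode mismatch described above.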

Hi,
Thanks for getting back to me. I started a run last night (for ref, using 500K training images and 70K validation images) with model.train() in the training section and model.eval() in the validation section, and although my train loss is coming down nicely and about where I’d expect, the validation loss is two orders of magnitude greater than I’d expect.

However, I had suspected something was a bit fishy with the validation loss, as it was about 30% of the training loss, which I put down to dropout in the training section.

It’s really odd, by ignoring model.train and model.eval() I seem to have created a huge Frankenstein ensemble network that gives really good results, just slightly inconsistent between runs!

Also, here’s the optimiser I’m using;
optim.Adam(model.parameters(), lr= learning_rate)

where the learning rate is 1e-4.

Hi @JAllsopp,

Could this be the issue? I know when casting a model to a new device, it’s not self-referential, i.e.,

model = model.to("cuda") #works
model.to("cuda") #doesn't work

I believe (@ptrblck will know for sure) that the same applies for casting to a new dtype.

The to() operation (and direct transformations) are applied recursively on modules (and all properly registered parameters, buffers, and submodules):

model = torchvision.models.resnet18()

print(model.conv1.weight.dtype)
# torch.float32

model.double()
print(model.conv1.weight.dtype)
# torch.float64

model.to(torch.bfloat16)
print(model.conv1.weight.dtype)
# torch.bfloat16

so this shouldn’t be an issue.
The model.to("cuda") call should also work, so it would be interesting to hear if you’ve encountered any issues with it.
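
One related point that is easy to mix up: the in-place behaviour applies to nn.Module.to(), while Tensor.to() returns a new tensor and leaves the original untouched. A small sketch of the difference, using a dtype change so it runs without a GPU (the layer and tensor here are made up for illustration):

```python
import torch
import torch.nn as nn

# Module.to() converts the parameters in place and returns the module itself.
layer = nn.Linear(2, 2)
returned = layer.to(torch.float64)
print(returned is layer)
print(layer.weight.dtype)

# Tensor.to() returns a new tensor; the original keeps its dtype (or device).
x = torch.zeros(2)
x.to(torch.float64)          # result is discarded, x is unchanged
dtype_before = x.dtype
x = x.to(torch.float64)      # assigning the result is what takes effect
dtype_after = x.dtype
print(dtype_before, dtype_after)
```

So for modules both spellings are fine, but input tensors need the result of to() assigned back.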
