Nonsensical Unet output with model.eval()

Hi Forum,

I’m having an issue with a pretrained model (a Unet from the Segmentation Models PyTorch library) that I use for tile-wise segmentation of large (6000x6000) histological images. The model performs well during training and testing, where I use model.train() for the training step and model.eval() for the testing step - no unexpected behavior there. I use a batch size of 20, as my GPU only has 8 GB of memory. (For a snippet of the training code, see below.)

In a second step, I apply the model to a large number of images in a separate script. After loading the model with the following code, I get purely nonsensical results (see the images below). When I omit model.eval(), the results look good; I still get some batch effects, which I can mitigate somewhat with proper shuffling.

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        classes=3,
        activation=None,
    )

    model.load_state_dict(torch.load(BST_MODEL))
    model = model.to(DEVICE)
    model.eval()
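
For completeness, the inference itself is roughly this (a simplified sketch of what Process.py does, not the verbatim code; tile_loader stands in for my DataLoader that yields batches of 256x256 tiles):

    with torch.no_grad():
        for batch in tile_loader:                  # batches of 256x256 tiles
            tiles = batch['image'].to(DEVICE).float()
            logits = model(tiles)                  # (B, 3, 256, 256)
            preds = torch.argmax(logits, dim=1)    # per-pixel class indices
            # ... preds are written back into the full-size prediction map ...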

Here’s a collage of the (histological) images I’m working on. I’m aware that it probably won’t mean much to you, but I hope it illustrates the problem better than loss metrics would.

The only difference in the code leading to the two outputs is the .eval() call after moving the model to the GPU.

I suspect that the issue is closely related to things that have been discussed in other threads (like this or this).

Things I have tried so far:

  • setting track_running_stats=False for the BatchNorm2d layers (see the sketch after this list)
  • varying the momentum parameter of the BatchNorm2d layers
  • increasing the batch size to 32
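
The batchnorm options were toggled on the already-loaded model, roughly like this (a simplified sketch; the momentum value is just an example of the values I varied):

    import torch.nn as nn

    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = False  # disable running-stats tracking
            module.momentum = 0.01              # example value, varied across runs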

Note: Since the images are too large to process in one pass, I split each image into 256x256 tiles, which I feed through the model as (shuffled) batches. I process the whole image several times with offset tile coordinates, and the multiple predictions are then averaged to suppress tiling artifacts (rough sketch below).
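
The tiling and averaging logic looks roughly like this (a simplified sketch with placeholder names: image is a (3, H, W) tensor of the full image, offsets are the grid shifts for the individual passes; batching, shuffling and edge handling are omitted for brevity):

    TILE = 256
    offsets = [0, 64, 128, 192]                # tile-grid shifts, one pass each

    sum_pred = torch.zeros(3, H, W)            # accumulated class probabilities
    counts = torch.zeros(1, H, W)              # how often each pixel was predicted

    model.eval()
    with torch.no_grad():
        for off in offsets:
            for y in range(off, H - TILE + 1, TILE):
                for x in range(off, W - TILE + 1, TILE):
                    tile = image[:, y:y+TILE, x:x+TILE].unsqueeze(0).to(DEVICE)
                    probs = torch.softmax(model(tile)[0], dim=0).cpu()  # (3, 256, 256)
                    sum_pred[:, y:y+TILE, x:x+TILE] += probs
                    counts[:, y:y+TILE, x:x+TILE] += 1

    avg_pred = sum_pred / counts.clamp(min=1)  # average over overlapping passes
    segmentation = avg_pred.argmax(dim=0)      # final per-pixel class map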

My full code is available here. As stated above, training/testing (Train.py) works fine; things go sideways after loading and running the model (Process.py).

I guess I am making some rookie mistake, but I cannot identify the error.
Thank you in advance for your time and help!

PS: Train/Test code snippet:

    # imports used by this snippet
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from sklearn.metrics import jaccard_score
    import helper_functions

    # dataloaders & PerformanceMeter
    train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_dataloader   = DataLoader(val_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
    Monitor = helper_functions.PerformanceMeter()
    
    # score lists
    train_score = []
    test_score = []
    
    # train
    for epoch in range(EPOCHS):
        
        model.train()
        tk0 = tqdm(train_dataloader, total=len(train_dataloader))
        
        for b_idx, data in enumerate(tk0):
            
            # move all tensors of the batch to the GPU
            for key, value in data.items():
                data[key] = value.to(device).float()
                
            # train step: reset gradients, forward pass, backprop, update
            optimizer.zero_grad()
            data['prediction'] = model(data['image'])
            loss = criterion(data['prediction'], data['mask'][:, 0].long())
            loss.backward()
            optimizer.step()
            tk0.set_postfix(loss=loss.item())

        # evaluate
        model.eval()
        with torch.no_grad():
            
            tk1 = tqdm(val_dataloader, total=len(val_dataloader))
            dice = np.zeros(N_CLASSES)  # accumulates per-class Jaccard scores over the validation batches
            for b_idx, data in enumerate(tk1):
                
                # forward pass on the validation batch
                data['prediction'] = model(data['image'].to(device).float())
                out = torch.argmax(data['prediction'], dim=1).view(-1)
                mask = data['mask'].view(-1)
                
                score = jaccard_score(mask.cpu(), out.cpu(), average=None)
                dice += score