I’m having an issue with a pretrained model (Unet from the Segmentation Models PyTorch library) that I use for tile-wise segmentation of large (6000x6000) histological images. The model performs well during training and testing, where I use model.train() for the training step and model.eval() for the testing step - no unexpected behavior there. I use a batch size of 20, as my GPU only has 8 GB of memory. (For a snippet of the training code, see below.)
In a second step, I apply the model to a large number of images in a separate script. After loading the model with the following code, I get purely nonsensical results (see images below). When I omit model.eval(), the result looks good; I still get some batch effects, which I can mitigate somewhat with proper shuffling.
```
model = smp.Unet(
    encoder_name='resnet50',
    encoder_weights='imagenet',
    classes=3,
    activation=None,
)
model.load_state_dict(torch.load(BST_MODEL))
model = model.to(DEVICE)
model.eval()
```
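If I understand the BatchNorm semantics correctly, omitting model.eval() mainly means that the BatchNorm layers normalize with per-batch statistics instead of their stored running averages. A more targeted variant would be to keep eval mode globally but flip only the BatchNorm layers back to train mode - a sketch of something I could try, not code from my repo:

```
import torch.nn as nn

model = model.to(DEVICE)
model.eval()
# Sketch: put only the BatchNorm layers back into train mode so they
# normalize with per-batch statistics instead of the stored running
# mean/var. Note that in train mode they also keep updating those
# running statistics on every forward pass.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.train()
```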
Here’s a collage of the (histological) images I’m working on. I am aware that it probably won’t mean much to you, but I hope it illustrates the problem better than loss metrics would.
The only difference in the code leading to the two outputs is the .eval() call after moving the model to the GPU.
Things I have tried so far:
- Setting track_running_stats=False for the BatchNorm2d layers (see the sketch after this list)
- Varying the momentum parameter of the BatchNorm2d layers
- Increasing the batch size to 32
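For reference, this is roughly how I adjusted those BatchNorm settings - a hypothetical helper, not the exact code from my repo. As far as I can tell, in recent PyTorch versions an eval-mode BatchNorm layer only falls back to batch statistics when its running buffers are None, so flipping track_running_stats alone may not be enough:

```
import torch.nn as nn

def configure_batchnorm(model, track_running_stats=False, momentum=0.1):
    # Hypothetical helper: adjust every BatchNorm2d layer in place.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = track_running_stats
            m.momentum = momentum
            if not track_running_stats:
                # Clear the running buffers so eval mode actually uses
                # batch statistics instead of the stored averages.
                m.running_mean = None
                m.running_var = None
```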
Note: Since the whole images are too large for the GPU, I process each image as 256x256 tiles, which I feed through the model as (shuffled) batches. I process each image several times with offset tile coordinates and average the resulting predictions to suppress tile artifacts. A simplified sketch of this scheme follows below.
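Here is a simplified, unbatched sketch of that tiling scheme (the predict_tiled helper is hypothetical and purely illustrative - the real code shuffles the tiles into batches):

```
import numpy as np
import torch

def predict_tiled(model, image, n_classes=3, tile=256, offsets=(0, 128), device='cuda'):
    # Accumulate logits over several tilings shifted by `offsets`,
    # then average to suppress tile artifacts.
    H, W = image.shape[:2]
    acc = np.zeros((n_classes, H, W), dtype=np.float32)
    counts = np.zeros((H, W), dtype=np.float32)
    with torch.no_grad():
        for off in offsets:
            for y in range(off, H - tile + 1, tile):
                for x in range(off, W - tile + 1, tile):
                    patch = image[y:y + tile, x:x + tile]  # (tile, tile, 3)
                    t = (torch.from_numpy(patch)
                         .permute(2, 0, 1)                 # -> (3, tile, tile)
                         .unsqueeze(0)                     # -> (1, 3, tile, tile)
                         .float().to(device))
                    logits = model(t)[0].cpu().numpy()     # (n_classes, tile, tile)
                    acc[:, y:y + tile, x:x + tile] += logits
                    counts[y:y + tile, x:x + tile] += 1
    return acc / np.maximum(counts, 1)  # average where tilings overlap
```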
My complete code is available here. As stated above, training/testing (Train.py) works fine; things go sideways after loading and running the model (Process.py).
I guess I am making some rookie mistake, but I cannot identify the error.
Thank you in advance for your time and help!
PS: Train/Test code snippet:
```
# imports (shown for completeness)
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import jaccard_score

# dataloaders & PerformanceMeter
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
Monitor = helper_functions.PerformanceMeter()

# score lists
train_score = []
test_score = []

# train
for epoch in range(EPOCHS):
    model.train()
    tk0 = tqdm(train_dataloader, total=len(train_dataloader))
    for b_idx, data in enumerate(tk0):
        # move images to the GPU
        for key, value in data.items():
            data[key] = value.to(device).float()

        # train step
        optimizer.zero_grad()  # reset gradients for every batch
        data['prediction'] = model(data['image'])
        loss = criterion(data['prediction'], data['mask'][:, 0].long())
        loss.backward()
        optimizer.step()
        tk0.set_postfix(loss=loss.cpu().detach().numpy())

    # evaluate
    model.eval()
    with torch.no_grad():
        tk1 = tqdm(val_dataloader, total=len(val_dataloader))
        dice = np.zeros(N_CLASSES)
        for b_idx, data in enumerate(tk1):
            # per-class Jaccard/IoU score, accumulated over batches
            data['prediction'] = model(data['image'].to(device).float())
            out = torch.argmax(data['prediction'], dim=1).view(-1)
            mask = data['mask'].view(-1)
            score = jaccard_score(mask.cpu(), out.cpu(), average=None)
            dice += score
```