Hi Forum,
I’m having an issue with a pretrained model (Unet, library: PyTorch Segmentation Models), which I use for tile-wise segmentation of large (6000x6000) histological images. The model performs well during training and testing, where I use model.train() for the training step and model.eval() for the testing step; no unexpected behavior here. I use a batch size of 20, as my GPU only has 8 GB of RAM (for a snippet of the training code, see below).
In a second step, I apply the model to a large number of images in a separate script. Now, after loading the model with the following code, I get purely nonsensical results (see images below). When I omit model.eval(), the results look good; I still get some batch effects, which I can mitigate a bit with proper shuffling.
import torch
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name='resnet50',
    encoder_weights='imagenet',
    classes=3,
    activation=None,
)
model.load_state_dict(torch.load(BST_MODEL))
model = model.to(DEVICE)
model.eval()
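To illustrate the discrepancy directly, a single forward pass in both modes on the same batch already shows it. This is a minimal sketch, assuming batch is any (N, 3, 256, 256) float tensor of tiles on DEVICE:

model.train()  # note: a forward pass in train mode also updates the BatchNorm running stats
with torch.no_grad():
    out_train = model(batch)  # normalized with per-batch statistics -> looks fine
model.eval()
with torch.no_grad():
    out_eval = model(batch)   # normalized with stored running statistics -> nonsensical
print((out_train - out_eval).abs().max())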
Here’s a collage of the (histological) images I’m working on. I am aware that it probably won’t mean much to you, but I hope it illustrates the problem better than loss metrics would.
[image: collage of segmentation results with vs. without model.eval()]
The only difference in the code leading to the two outputs is the .eval() call after moving the model to the GPU.
I suspect that the issue is closely related to things that have been discussed in other threads (like this or this).
Things I have tried so far:
- Setting track_running_stats=False for the BatchNorm2d layers (see the sketch after this list)
- Varying the momentum parameter of the BatchNorm2d layers
- Increasing the batch size to 32
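For reference, a minimal sketch of how I toggled the BatchNorm settings after loading the weights. The attribute names are those of torch.nn.BatchNorm2d; resetting the stored statistics to None matters because eval mode otherwise keeps using them:

import torch.nn as nn

# sketch: force batch statistics even in eval mode
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False
        m.running_mean = None  # eval mode falls back to batch stats only when these are None
        m.running_var = None
        # m.momentum = 0.1     # variant I also tried: different momentum values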
Note: since the whole images are too large, I process them as 256x256 tiles, which I feed through the model as (shuffled) batches. I process each image several times with offset tile coordinates; the multiple predictions are then averaged to suppress tile artifacts.
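To make that step concrete, here is a stripped-down sketch of the offset-and-average scheme. The helper name is hypothetical, my actual script batches shuffled tiles instead of looping one by one, and edge remainders are ignored here:

import torch

def predict_tiled(model, image, tile=256, offsets=(0, 128)):
    # image: (C, H, W) float tensor; returns averaged class scores of shape (3, H, W)
    _, H, W = image.shape
    scores = torch.zeros(3, H, W)
    counts = torch.zeros(1, H, W)
    model.eval()
    with torch.no_grad():
        for off in offsets:  # shifted tilings to suppress tile artifacts
            for y in range(off, H - tile + 1, tile):
                for x in range(off, W - tile + 1, tile):
                    patch = image[:, y:y+tile, x:x+tile].unsqueeze(0).to(DEVICE)
                    scores[:, y:y+tile, x:x+tile] += model(patch)[0].cpu()
                    counts[:, y:y+tile, x:x+tile] += 1
    return scores / counts.clamp(min=1)  # average where the tilings overlap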
My code is fully available here. As stated above, training/testing (Train.py) works fine; things go sideways after loading and running the model (Process.py).
I guess I am making some rookie mistake, but I cannot identify the error.
Thank you in advance for your time and help!
PS: Train/Test code snippet:
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import jaccard_score

# dataloaders & PerformanceMeter
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, BATCH_SIZE, shuffle=True, num_workers=num_workers)
Monitor = helper_functions.PerformanceMeter()

# score lists
train_score = []
test_score = []

# train
for epoch in range(EPOCHS):
    model.train()
    tk0 = tqdm(train_dataloader, total=len(train_dataloader))
    for b_idx, data in enumerate(tk0):
        # move images to the GPU
        for key, value in data.items():
            data[key] = value.to(device).float()
        # train
        optimizer.zero_grad()  # reset gradients for every batch
        data['prediction'] = model(data['image'])
        loss = criterion(data['prediction'], data['mask'][:, 0].long())
        loss.backward()
        optimizer.step()
        tk0.set_postfix(loss=loss.cpu().detach().numpy())

    # evaluate
    model.eval()
    with torch.no_grad():
        tk1 = tqdm(val_dataloader, total=len(val_dataloader))
        dice = np.zeros(N_CLASSES)
        for b_idx, data in enumerate(tk1):
            # eval
            data['prediction'] = model(data['image'].to(device).float())
            out = torch.argmax(data['prediction'], dim=1).view(-1)
            mask = data['mask'].view(-1)
            score = jaccard_score(mask.cpu(), out.cpu(), average=None)
            dice += score
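The snippet cuts off here; the epoch ends with the accumulated per-class scores being averaged over the validation batches, roughly like this (sketch, variable names as above):

    # average the accumulated per-class IoU over the validation batches
    dice /= len(val_dataloader)
    test_score.append(dice.mean())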