I’m training a segmentation model on some medical images with the SMP (segmentation_models_pytorch) package, but I can’t get the training to run. I wrote a dataset class that returns a training image tensor of shape [3, 512, 512] and its corresponding target mask of shape [1, 512, 512], where each pixel takes a value from 0 to 3, one value per class (four classes). Then I run the following code:
```python
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

# RetinaDataset, data_transform, mask_transform, CLASSES and batch_size are defined elsewhere (not shown)

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax2d'

model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = RetinaDataset('./data/new_format/train',
                              data_transform=data_transform,
                              mask_transform=mask_transform,
                              classes=['background', 'irf', 'srf', 'ped'])
val_dataset = RetinaDataset('./data/new_format/val',
                            data_transform=data_transform,
                            mask_transform=mask_transform,
                            classes=['background', 'irf', 'srf', 'ped'])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

loss = smp.losses.SoftCrossEntropyLoss(smooth_factor=.01)
metrics = [
    utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    verbose=True,
)
valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    verbose=True,
)

max_score = 0
for i in range(5):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './models/best_model.pth')
        print('Model saved!')

    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')
```
Then I get the error shown in the title of this post. I’ve tried casting the mask to uint64 in many places, but that hasn’t helped. I also suspect something is wrong with the loss function I’m using.
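For reference, my understanding is that PyTorch cross-entropy-style losses want the target as class indices of shape [B, H, W] with dtype `torch.long`, which is signed int64, not uint64. As far as I can tell, SMP's `SoftCrossEntropyLoss` is in this group because it picks log-probabilities by class index. The sketch below uses only plain PyTorch with random data. `to_class_indices` is a hypothetical helper that removes the channel dimension and casts the mask.

```python
import torch
import torch.nn.functional as F


def to_class_indices(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [B, 1, H, W] class-id mask -> [B, H, W] int64 indices."""
    return mask.squeeze(1).long()  # .long() means torch.int64 (signed), not uint64


torch.manual_seed(0)
batch, num_classes, size = 2, 4, 8

logits = torch.randn(batch, num_classes, size, size)  # stand-in for raw model output
# Float mask shaped like the one my dataset returns: [B, 1, H, W], values 0..3
mask = torch.randint(0, num_classes, (batch, 1, size, size)).float()

target = to_class_indices(mask)                       # [B, H, W], torch.int64
log_probs = F.log_softmax(logits, dim=1)

# Picking each pixel's log-probability by class index requires an int64 index
picked = log_probs.gather(1, target.unsqueeze(1))     # [B, 1, H, W]
nll = -picked.mean()

# The same gather with the float mask as the index raises a dtype error
try:
    log_probs.gather(1, mask)
    float_index_failed = False
except RuntimeError:
    float_index_failed = True

# Should agree with PyTorch's built-in cross entropy on the converted target
reference = F.cross_entropy(logits, target)
print(target.shape, target.dtype, float(nll), float(reference))
```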
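About the loss: `activation='softmax2d'` makes the model output probabilities, because it applies a softmax over the class dimension. As far as I can tell, SMP's `SoftCrossEntropyLoss` also applies `log_softmax` to its input. Used together, the two would apply softmax twice. Below is a small plain-PyTorch check of what that does, with random logits standing in for the model output. `F.cross_entropy` stands in for the SMP loss here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)          # [B, C, H, W] raw scores
target = torch.randint(0, 4, (2, 8, 8))   # [B, H, W] class ids, int64

probs = nn.Softmax2d()(logits)            # softmax over the class dimension (dim=1)
sums_to_one = torch.allclose(probs.sum(dim=1), torch.ones(2, 8, 8))

loss_on_logits = F.cross_entropy(logits, target)  # softmax applied once
loss_on_probs = F.cross_entropy(probs, target)    # softmax effectively applied twice

print(float(loss_on_logits), float(loss_on_probs))
```

With four classes, the double-softmaxed loss can never drop below -log(e / (e + 3)) ≈ 0.74 per pixel, however confident the model is.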
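On the metric: as I understand it, `utils.metrics.IoU(threshold=0.5)` compares the thresholded prediction with the ground truth element by element. For four classes it would therefore need a one-hot [B, 4, H, W] target rather than a single-channel mask of class ids. Here is a sketch of that conversion with `torch.nn.functional.one_hot`. `to_one_hot` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def to_one_hot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: [B, 1, H, W] class-id mask -> [B, C, H, W] float one-hot."""
    indices = mask.squeeze(1).long()            # one_hot requires int64 input
    one_hot = F.one_hot(indices, num_classes)   # [B, H, W, C]
    return one_hot.permute(0, 3, 1, 2).float()  # [B, C, H, W]


mask = torch.tensor([[[[0, 1], [2, 3]]]], dtype=torch.float32)  # [1, 1, 2, 2]
encoded = to_one_hot(mask, num_classes=4)
print(encoded.shape)        # torch.Size([1, 4, 2, 2])
print(encoded[0, :, 1, 1])  # tensor([0., 0., 0., 1.])
```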
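Separately, `preprocessing_fn` is created but never passed to the datasets. My understanding is that for `resnet50` with `imagenet` weights, it scales pixel values to [0, 1] and then normalizes each channel with the standard ImageNet mean and std. SMP's function works on channels-last NumPy arrays. `imagenet_normalize` below is a hypothetical channels-first PyTorch equivalent. It assumes the image holds 0..255 values and that `data_transform` doesn't already normalize.

```python
import torch

# Standard ImageNet statistics (assumed to match the resnet50/imagenet settings)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def imagenet_normalize(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [3, H, W] tensor with values 0..255 -> normalized float tensor."""
    image = image.float() / 255.0
    return (image - IMAGENET_MEAN) / IMAGENET_STD


image = torch.full((3, 512, 512), 255, dtype=torch.uint8)  # an all-white test image
out = imagenet_normalize(image)
print(out.shape, out[:, 0, 0])
```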
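The loop also only runs `range(5)`, so the `if i == 25` branch that lowers the learning rate never fires. For a longer run, here is a sketch of the same drop from 1e-4 to 1e-5 using PyTorch's `MultiStepLR` instead of editing `param_groups` by hand. This swaps in a different mechanism than the one in my code above. A single stand-in parameter replaces the model.

```python
import math

import torch
from torch.optim.lr_scheduler import MultiStepLR

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = torch.optim.Adam(params, lr=1e-4)
# Lower the lr 10x from epoch 26 on, matching the `if i == 25` branch,
# which runs after epoch 25 has finished
scheduler = MultiStepLR(optimizer, milestones=[26], gamma=0.1)

lrs = []
for epoch in range(40):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    # ... train_epoch.run(...) / valid_epoch.run(...) would go here ...
    optimizer.step()                             # no gradients here, so nothing is updated
    scheduler.step()

print(lrs[25], lrs[26])
```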
Thanks in advance