Segmentation: RuntimeError: gather(): Expected dtype int64 for index

I’m training a segmentation model on medical images using the segmentation_models_pytorch (SMP) package, but I’m running into an error during training. I wrote a dataloader that returns a training image tensor of size [3, 512, 512] and the corresponding target mask of size [1, 512, 512], where each pixel takes a value from 0 to 3 for the four classes. Then I run the following code:

import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
from torch.utils.data import DataLoader

# RetinaDataset, data_transform, mask_transform and batch_size are defined elsewhere
CLASSES = ['background', 'irf', 'srf', 'ped']

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax2d'

model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = RetinaDataset('./data/new_format/train',
                              data_transform=data_transform,
                              mask_transform=mask_transform,
                              classes=['background', 'irf', 'srf', 'ped'])

val_dataset = RetinaDataset('./data/new_format/val',
                            data_transform=data_transform,
                            mask_transform=mask_transform,
                            classes=['background', 'irf', 'srf', 'ped'])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

loss = smp.losses.SoftCrossEntropyLoss(smooth_factor=.01)
metrics = [
    utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])


train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    verbose=True,
)

max_score = 0

for i in range(5):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './models/best_model.pth')
        print('Model saved!')

    if i == 25:  # note: never reached with range(5); only applies if the epoch count is raised
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')

Then I get the error seen in the title of this post. I’ve tried casting the mask to int64 in several places, but that hasn’t worked. I also suspect there is something wrong with the loss function I’m using.
Thanks in advance

Did you try to cast it via tensor = tensor.long()?
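Something along these lines in your Dataset's __getitem__ would apply the cast before the loss ever sees the mask (the helper and attribute names below are just placeholders for whatever RetinaDataset actually uses):

def __getitem__(self, idx):
    image, mask = self._load_pair(idx)       # placeholder for however the image/mask pair is loaded
    image = self.data_transform(image)
    mask = self.mask_transform(mask).long()  # cast the class-index mask to int64
    return image, mask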

I’m pretty sure I tried that at some point and it didn’t work for me. I’ve actually got it working now, though.

I forgot to include my mask_transform in the first post; originally it was

from torchvision import transforms

mask_transform = transforms.Compose([
    transforms.PILToTensor(),
])

Then I added transforms.ConvertImageDtype(torch.int64) as a second transform, and that fixed the dtype error.
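So the full mask_transform is now:

mask_transform = transforms.Compose([
    transforms.PILToTensor(),                   # PIL mask -> integer tensor, no value rescaling
    transforms.ConvertImageDtype(torch.int64),  # give the loss the int64 index dtype it asked for
])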

I also changed the loss from SoftCrossEntropyLoss to DiceLoss.
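Roughly, the new loss setup looks like this (the mode argument is my choice based on the integer class-index masks; adjust to your setup):

loss = smp.losses.DiceLoss(mode='multiclass')
# depending on the smp version, the utils train loop logs loss.__name__, so set it if it's missing:
loss.__name__ = 'dice_loss'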

Thanks