Hello,
I’m new to Pytorch and I’m trying to perform semantic segmentation of 3d images. When I use cross entropy loss with ignore index, I get the above error.
def training_fn(net,
device,
input_dim,
epochs: int = 1,
batch_size: int = 1,
learning_rate: float = 1e-3,
valiation_percent=0.1,
save_checkpoint: bool = True):
# create dataset
dataset = hdf5.Hdf5Dataset(data_file_path, image_dim=input_dim, contains_mask=True)
# create training and validation dataset
n_dataset = dataset.__len__()
n_val = round(n_dataset * valiation_percent)
n_train = n_dataset - n_val
train_set, val_set = random_split(dataset, [n_train, n_val])
# create dataloaders
train_dataloader = DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=1, pin_memory=True)
val_dataloader = DataLoader(val_set, shuffle=False, batch_size=batch_size, num_workers=1, pin_memory=True)
# specify loss functions, optimizers
loss_fn = nn.CrossEntropyLoss(ignore_index=2)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(1, epochs + 1):
print('Epoch {}/{}'.format(epoch, epochs))
print('-' * 10)
if os.path.exists(checkpoint_path): # checking if there is a file with this name
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint)
# scheduler.step()
for param_group in optimizer.param_groups:
print("LR", param_group['lr'])
net.to(device)
net.train()
running_loss = 0
i = 0
# training
for batch in train_dataloader:
image = batch[0]
image = image.permute(1, 0, 2, 3)
true_mask = batch[1]
image = image.to(device=device, dtype=torch.float32)
true_mask = true_mask.to(device=device, dtype=torch.int64)
true_mask = F.one_hot(true_mask, config.n_classes)
true_mask = true_mask.permute(0, 1, 4, 2, 3)
true_mask = true_mask[-1, :, :, :]
true_mask = true_mask.type(torch.float32)
optimizer.zero_grad()
pred = net(image)
# pred = pred.permute(1, 0, 2, 3)
loss = loss_fn(torch.round(pred), torch.round(true_mask))
i += 1
# Backpropagation
loss.backward()
optimizer.step()
running_loss += loss.item()
# print(f'Accuracy score, Hamming loss - : {mIoU(pred, true_mask)}')
if (i == n_train):
print(f'Epoch : {epoch}, loss: {(running_loss / i):.4f}')
I have three labels - 0, 1, 2. I tried to round values so that it would be integer, but it is not working. Any suggestions would be helpful.