How to make Cross Entropy Loss work with Cutmix & Mixup?

Hi everyone,
I’m following the official tutorial on applying both CutMix and MixUp during training for data augmentation, but as soon as training starts I get this runtime error from the criterion call:

    0D or 1D target tensor expected, multi-target not supported

Here is my training code:

    import numpy as np
    import torch
    from torchvision.transforms import v2
    from tqdm import tqdm

    # model, data_loader, args, NUM_CLASSES and the acc() helper are defined elsewhere (not shown)
    cutmix = v2.CutMix(num_classes=NUM_CLASSES)
    mixup = v2.MixUp(num_classes=NUM_CLASSES)
    # Apply either CutMix or MixUp (chosen at random) to each batch
    cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.epochs):
        train_losses = []
        train_acc = 0.0
        total = 0
        print(f"[Epoch {epoch+1} / {args.epochs}]")

        model.train()
        pbar = tqdm(data_loader)
        for i, (x, y) in enumerate(pbar):
            image = x.to(args.device)
            label = y.to(args.device)
            # Mix the batch; label becomes a (batch_size, NUM_CLASSES) tensor of soft labels
            image, label = cutmix_or_mixup(image, label)
            optimizer.zero_grad()

            output = model(image)
            label = label.squeeze()

            loss = criterion(output, label)  # the RuntimeError is raised here
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            total += label.size(0)

            train_acc += acc(output, label)

        epoch_train_loss = np.mean(train_losses)
        epoch_train_acc = train_acc / total

        print(f'Epoch {epoch+1}')
        print(f'train_loss : {epoch_train_loss}')
        print(f'train_accuracy : {epoch_train_acc * 100:.3f}')

What am I missing? The tutorial says that I can pass the transformed labels as-is to a loss function like cross entropy.
Tutorial: link

Thanks in advance

If you are passing one-hot encoded labels, make sure they are passed as a floating-point tensor. Support for this was added in PyTorch 1.10 and allows you to pass “soft” labels (class probabilities) to nn.CrossEntropyLoss.
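To make the soft-label rule a bit more concrete: in current PyTorch releases, nn.CrossEntropyLoss treats the target as class probabilities only when it is a floating-point tensor with exactly the same shape as the model output. If the shapes differ (for example, if the model's last layer produces a different number of classes than the `num_classes` given to CutMix/MixUp), the target is interpreted as class indices instead, and a 2D target then produces the same "multi-target not supported" error even after casting it to float. The sketch below uses made-up sizes (4 samples, 10 vs. 5 classes) purely to illustrate this behaviour; it is not a diagnosis of the code above.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_classes = 4, 10
criterion = nn.CrossEntropyLoss()

# Model output (logits) with shape (N, C)
logits = torch.randn(batch_size, num_classes)

# Case 1: float soft labels with the same shape as the logits -> accepted
soft_labels = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
loss_ok = criterion(logits, soft_labels)
print(f"matching shapes, loss = {loss_ok.item():.4f}")

# Case 2: float soft labels for a different number of classes (5 instead of 10).
# The shapes no longer match, so the target is interpreted as class indices,
# and a 2D index target is rejected.
mismatched_labels = torch.softmax(torch.randn(batch_size, 5), dim=1)
error_message = None
try:
    criterion(logits, mismatched_labels)
except RuntimeError as exc:
    error_message = str(exc)
print(f"mismatched shapes, error = {error_message}")
```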

I added label = label.to(torch.float) before the criterion call but I keep getting the same error.

The tutorial works fine for me, so could you post a minimal and executable code snippet reproducing the issue, please?