How to compute accuracy when using v2.MixUp?

Hi all,

When using v2.MixUp, how would one go about calculating accuracy when the mixing coefficient lambda (lam), targets_a, and targets_b are not returned?

The original implementation computes accuracy by comparing the predictions with each set of targets, weighting the two match counts by lam and 1 - lam, and summing them:

        correct += (lam * predicted.eq(targets_a.data).cpu().sum().float()
                    + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())
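For reference, the classic (pre-v2) pattern that this snippet comes from can be sketched as below. This is a minimal sketch, not the original code: mixup_data and mixup_correct are hypothetical helper names, and alpha=1.0 is an assumed default for the Beta distribution that lam is drawn from.

```python
import torch


def mixup_data(x, y, alpha=1.0):
    """Mix a batch and return (mixed_x, targets_a, targets_b, lam).

    Hypothetical helper in the style of common MixUp reference code.
    """
    # Sample the mixing coefficient lam ~ Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))  # random partner for every sample
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_correct(outputs, targets_a, targets_b, lam):
    """Lam-weighted number of correct predictions (a float, not an integer)."""
    predicted = outputs.argmax(dim=1)
    return (lam * predicted.eq(targets_a).sum().float()
            + (1 - lam) * predicted.eq(targets_b).sum().float()).item()
```

Epoch accuracy is then the sum of mixup_correct over all batches divided by the number of samples seen.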

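However, v2.MixUp(num_classes=NUM_CLASSES) does not return these values: it only returns the mixed images and the mixed labels, which are soft label vectors of shape (batch_size, NUM_CLASSES).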
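One way to get the same weighted count back is to read it off the soft labels. Each mixed label row is lam * one_hot(a) + (1 - lam) * one_hot(b), so its entry at the predicted class p is lam * [p == a] + (1 - lam) * [p == b], which is exactly the per-sample term of the formula above (and 1 when a == b and the prediction matches). The sketch below needs only PyTorch; mixup_soft_labels is a hypothetical stand-in that builds labels the way torchvision's MixUp appears to in its current source (pairing each sample with the batch rolled by one), which is an assumption, so the check does not rely on torchvision itself.

```python
import torch
import torch.nn.functional as F


def mixup_soft_labels(labels, num_classes, lam):
    """Hypothetical stand-in: lam * one_hot(y) + (1 - lam) * one_hot(y rolled by one)."""
    one_hot = F.one_hot(labels, num_classes).float()
    return lam * one_hot + (1 - lam) * one_hot.roll(1, dims=0)


def soft_correct(outputs, soft_targets):
    """Lam-weighted correct count recovered from soft labels alone.

    Row i of soft_targets is lam * e_a + (1 - lam) * e_b, so its entry at the
    predicted class p equals lam * [p == a] + (1 - lam) * [p == b].
    """
    predicted = outputs.argmax(dim=1)
    return soft_targets.gather(1, predicted.unsqueeze(1)).sum().item()
```

A coarser alternative is predicted.eq(targets.argmax(dim=1)), which only credits the dominant class of each mixed sample.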

Thanks a lot.

import torch
import torch.nn as nn
from torchvision.transforms import v2
from torch.utils.data import DataLoader

NUM_CLASSES = 10

# dataset, model, optimizer and device are defined elsewhere
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
criterion = nn.CrossEntropyLoss()

mixup = v2.MixUp(num_classes=NUM_CLASSES)

train_loss = 0.0
for images, labels in dataloader:
    # apply MixUp to the whole batch; labels become soft labels of shape (N, NUM_CLASSES)
    images, labels = mixup(images, labels)
    inputs = images.to(device)
    targets = labels.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward
    # track history if only in train
    with torch.set_grad_enabled(True):
        outputs = model(inputs)
        loss = criterion(outputs, targets)  # CrossEntropyLoss accepts soft (probability) targets

        # backward + optimize
        loss.backward()
        optimizer.step()

    # statistics
    train_loss += loss.item()

    # acc = ??
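Put together, the loop above could track both statistics roughly as follows. This is a sketch under assumptions: train_one_epoch is a hypothetical name, mixup can be any callable returning (images, soft_labels), such as v2.MixUp(num_classes=NUM_CLASSES), and nn.CrossEntropyLoss is relied on to accept class-probability targets (supported since PyTorch 1.10).

```python
import torch
import torch.nn as nn


def train_one_epoch(model, dataloader, mixup, criterion, optimizer, device):
    """Hypothetical helper: one MixUp epoch returning (mean loss, lam-weighted accuracy)."""
    model.train()
    train_loss, correct, total = 0.0, 0.0, 0
    for images, labels in dataloader:
        images, labels = mixup(images, labels)  # soft labels: (N, num_classes)
        inputs, targets = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)  # probability targets are accepted
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        # Soft-label entry at the predicted class = lam*[p == a] + (1 - lam)*[p == b]
        correct += targets.gather(1, predicted.unsqueeze(1)).sum().item()
        total += targets.size(0)
    return train_loss / len(dataloader), correct / total
```

Because the accuracy is computed on mixed inputs, it is mainly a training-progress signal; accuracy on an un-mixed validation set remains the usual way to compare models.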