How to print the probability of the correct category of a labeled example?

Hi, I’m trying to implement UDA (“Unsupervised Data Augmentation”), a paper which was released a couple of months ago.

I understand the Training Signal Annealing (TSA) idea, which removes part of the training examples from the loss (the ones the model already predicts correctly with a probability above a threshold), but the problem is my coding ability: I’m having difficulty implementing that part.
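
For the question in the title (printing the probability of the correct category for each labeled example), this is roughly what I have so far; `outputs` and `targets` are the same tensors as in my training loop below, and I’m not sure whether `gather` is the right way to pick out the true class:

    import torch.nn.functional as F

    # outputs: (batch, num_classes) logits, targets: (batch,) class indices
    probs = F.softmax(outputs, dim=1)
    correct_prob = probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # (batch,) prob of the true class
    print(correct_prob)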
The following is my code for applying UDA; I hope someone can help me.

    import itertools
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader

    criterion = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    unsupervised_loader = itertools.cycle(unsup_loader)  # cycle so the unlabeled loader never runs out

    for i, (inputs, targets) in enumerate(train_loader):

        if not unsupervised:  # `unsupervised` is a flag set outside this loop
            # plain supervised training
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        else:
            targets = targets.cuda()
            # unlabeled batch: two views (e.g. original / augmented) of the same examples
            unlabel1, unlabel2 = next(unsupervised_loader)
            data_all = torch.cat([inputs, unlabel1, unlabel2]).cuda()

            # one forward pass over labeled + unlabeled data
            preds_all = model(data_all)
            preds = preds_all[:len(inputs)]
            loss = criterion(preds, targets)  # loss for supervised learning

            preds_unsup = preds_all[len(inputs):]
            preds1, preds2 = torch.chunk(preds_unsup, 2)
            preds1 = F.softmax(preds1, dim=1).detach()  # fixed target distribution (no gradient)
            preds2 = F.log_softmax(preds2, dim=1)       # log-probabilities compared against it
            assert len(preds1) == len(preds2) == batch_unsup  # batch_unsup = unlabeled batch size

            loss_kldiv = F.kl_div(preds2, preds1, reduction='none')  # loss for unsupervised learning
            loss_kldiv = torch.sum(loss_kldiv, dim=1)
            loss += torch.mean(loss_kldiv) * ratio_unsup  # ratio_unsup = len(test) / (len(train) + len(test))
A brief description of my code: first, I made separate loaders for the supervised and unsupervised data. I concatenate a labeled batch with the two unlabeled batches, run everything through the model in one forward pass, and for the unsupervised examples I compute the KL divergence between the two sets of predictions.

How can I remove the predicted examples whose probability for the correct class is above the threshold, so that they are not included in the loss function?
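
To make it concrete, this is the kind of masking I was imagining for the supervised part, using `reduction='none'` and a boolean mask; `tsa_threshold` is just a placeholder name for the annealed threshold from the paper, and I don’t know if this is the right approach:

    import torch.nn.functional as F

    # outputs: (batch, num_classes) logits for the labeled batch, targets: (batch,) labels
    probs = F.softmax(outputs, dim=1)
    correct_prob = probs.gather(1, targets.unsqueeze(1)).squeeze(1)

    # keep only the examples the model is not yet confident about
    mask = (correct_prob < tsa_threshold).float()  # tsa_threshold: placeholder for the annealed threshold

    per_example_loss = F.cross_entropy(outputs, targets, reduction='none')
    sup_loss = (per_example_loss * mask).sum() / mask.sum().clamp(min=1.0)  # avoid dividing by zero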
any comment would be a great help to me. thanks!!