Gradients of my train labels are not accumulated after validation_loss.backward()

I want to train a model whose labels have gradients. I learned that cross entropy only accepts hard labels and cannot produce gradients with respect to the labels, so I use a custom cross entropy with one-hot encoded labels. I calculate the training loss for one batch of training images, then use the updated model to calculate the validation loss on the whole validation dataset. I sum all the validation losses and call backward(), but the gradients of the training labels are unchanged: they are the same as after the first backward() in the training loop.

```python
import torch
import torch.nn.functional as F

def softXEnt(input, target):
    # Cross entropy against a (possibly soft) one-hot target, averaged over the batch
    logprobs = F.log_softmax(input, dim=1)
    return -(target * logprobs).sum() / input.shape[0]

# model, optimizer, scheduler, train_loaders, val_loaders, num_class and
# num_epochs are defined elsewhere in my script
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for i, b in enumerate(train_loaders):
        images = b[0].cuda()
        labels = b[1].cuda()
        one_hot_encode_size = (labels.size()[0], num_class)
        # One-hot training labels; requires_grad_() makes them leaves that collect .grad
        one_hot_label = torch.zeros(one_hot_encode_size).cuda().scatter_(
            1, labels.unsqueeze(-1).expand(one_hot_encode_size), 1)
        out = model(images)
        loss = softXEnt(out, one_hot_label.requires_grad_())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        print(f"train loss:{loss.item()},step:{i}")

        # Validation loss of the updated model, summed over the whole validation set
        val_loss = torch.tensor(0.0, device="cuda")
        for val_b in val_loaders:
            val_images = val_b[0].cuda()
            val_labels = val_b[1].cuda()
            val_one_hot_encode_size = (val_labels.size()[0], num_class)
            val_one_hot_label = torch.zeros(val_one_hot_encode_size).cuda().scatter_(
                1, val_labels.unsqueeze(-1).expand(val_one_hot_encode_size), 1)
            out = model(val_images)
            val_loss += softXEnt(out, val_one_hot_label)
        # I expected this to change one_hot_label.grad, but it stays the same
        val_loss.backward()
```

My guess is that the optimizer step stops tracking operations, so the gradient from the validation loop does not backpropagate to the training loop? Please help!
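The behavior described in the question can be reproduced with a small CPU-only sketch. It assumes a hypothetical tiny `nn.Linear` model and random data in place of the original model and loaders, and plain SGD in place of the original optimizer and scheduler.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def soft_xent(logits, target):
    # Same as softXEnt in the question: batch mean of -sum(target * log_softmax)
    return -(target * F.log_softmax(logits, dim=1)).sum() / logits.shape[0]

# Hypothetical stand-ins for the real model and data
model = torch.nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x_train = torch.randn(8, 4)
y_train = F.one_hot(torch.randint(0, 3, (8,)), num_classes=3).float().requires_grad_()
x_val = torch.randn(5, 4)
y_val = F.one_hot(torch.randint(0, 3, (5,)), num_classes=3).float()

# Training step: backward fills y_train.grad, then the optimizer updates the weights
optimizer.zero_grad()
soft_xent(model(x_train), y_train).backward()
grad_after_train = y_train.grad.clone()
optimizer.step()

# The updated weights are still leaf tensors with no grad_fn
print(model.weight.is_leaf, model.weight.grad_fn)  # True None

# Validation backward: the graph ends at the leaf weights, so y_train.grad is untouched
soft_xent(model(x_val), y_val).backward()
print(torch.equal(y_train.grad, grad_after_train))  # True
```

`optimizer.step()` modifies the weights in place without recording the update in the autograd graph, so the validation graph stops at the weights and never reaches `y_train`.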

Hi Nguyen!

For some versions now, CrossEntropyLoss has supported a target that
consists of probabilistic (“soft”) labels, and you can backpropagate through
such a probabilistic target.
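A quick check of the soft-label support, assuming PyTorch 1.10 or later: when the target has the same shape as the input and holds floating-point class probabilities, `F.cross_entropy` treats it as soft labels. With the default `'mean'` reduction it should match the hand-written `softXEnt` from the question, and it passes a gradient back to the target.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 3)
# Soft targets: each row is a probability distribution over the 3 classes
soft_target = torch.softmax(torch.randn(4, 3), dim=1).requires_grad_()

# Floating-point target with the same shape as the input -> class probabilities
loss = F.cross_entropy(logits, soft_target)

# Hand-written equivalent (softXEnt from the question)
manual = -(soft_target * F.log_softmax(logits, dim=1)).sum() / logits.shape[0]
print(torch.allclose(loss, manual))  # True

# The loss is differentiable with respect to the target as well
loss.backward()
print(soft_target.grad.shape)  # torch.Size([4, 3])
```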

Best.

K. Frank


Oh, really! Can you tell me which version? I currently use torch 1.10+cu113 on an A6000 GPU. Do you think a newer torch version offers any way to backpropagate to the training targets?

Hi Nguyen!

The 1.10 documentation for CrossEntropyLoss states that a probabilistic
target is supported, and, from memory, this was added prior to 1.10.

If this doesn’t work for you, could you post a complete, runnable script
that illustrates the issue? (Although it’s moot now, it would be worth knowing
whether the 1.10 documentation is misleading.) In any event, it works for me
on 1.12, so you might consider upgrading to the latest stable version, 1.13
(which might be worth doing anyway).

Best.

K. Frank

Hello, yes, I’ve tried probabilistic labels with CrossEntropyLoss and it worked. But do you see the gradient of the training labels change after the validation loss’s backward()? I still get the same training-label gradients after validation_loss.backward(). My thought is that when the optimizer updates the parameters as $\theta \leftarrow \theta - \alpha \nabla_\theta L_{\text{train}}(x, y_{\text{train}})$, it doesn’t record a backward function (such as SubBackward0) for the parameters. They remain leaf tensors, so backpropagation does not go through the update step.
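The update rule in the post above can also be written out by hand instead of calling `optimizer.step()`. This is a technique not proposed in this thread, shown only as a sketch: if the step is computed with `torch.autograd.grad(..., create_graph=True)` and the new weights are ordinary tensors in the graph, autograd can carry the validation gradient back to the training labels. It assumes a hypothetical tiny linear model kept as plain tensors `w`, `b` and a single plain SGD step. For a real `nn.Module`, the forward pass would need to take the updated tensors, for example through `torch.func.functional_call` in PyTorch 2.x.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def soft_xent(logits, target):
    return -(target * F.log_softmax(logits, dim=1)).sum() / logits.shape[0]

# Hypothetical linear model as plain tensors so the update stays in the graph
w = torch.randn(3, 4, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
lr = 0.1  # alpha in the update rule

x_train = torch.randn(8, 4)
y_train = F.one_hot(torch.randint(0, 3, (8,)), num_classes=3).float().requires_grad_()
x_val = torch.randn(5, 4)
y_val = F.one_hot(torch.randint(0, 3, (5,)), num_classes=3).float()

train_loss = soft_xent(x_train @ w.T + b, y_train)
# create_graph=True makes the gradients themselves differentiable (also w.r.t. y_train)
gw, gb = torch.autograd.grad(train_loss, (w, b), create_graph=True)

# theta' = theta - alpha * grad L_train, written out instead of optimizer.step()
w_new = w - lr * gw
b_new = b - lr * gb

val_loss = soft_xent(x_val @ w_new.T + b_new, y_val)
# Gradient of the validation loss w.r.t. the training labels, through the update
(label_grad,) = torch.autograd.grad(val_loss, y_train)
print(label_grad.shape)  # torch.Size([8, 3])
```

The cost is an extra graph over the training step (second-order terms), so memory grows with the number of unrolled steps.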

Hi Nguyen!

I did not understand the full scope of your original question.

As I understand it, you backpropagate a loss that depends on your
training labels and then perform an optimization step. The updated
parameters of your model therefore depend on your training labels.

You then perform a forward pass on your validation data using the
updated model and calculate the validation loss with respect to the
validation labels. This validation loss therefore also depends on your
training labels.

You would like to compute the gradient of your validation loss with
respect to your training labels – a logically-consistent thing to do.
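For a single plain gradient-descent step with learning rate $\alpha$ (assuming no momentum or scheduler), write $\theta'$ for the updated parameters; $L_{\text{train}}$ depends on $\theta$ through the model. The chain rule then gives the quantity in question:

$$\theta' = \theta - \alpha\,\nabla_\theta L_{\text{train}}(x, y_{\text{train}})$$

$$\frac{\partial L_{\text{val}}(\theta')}{\partial y_{\text{train}}} = \left(\frac{\partial \theta'}{\partial y_{\text{train}}}\right)^{\top} \nabla_{\theta'} L_{\text{val}}(\theta') = -\alpha \left(\frac{\partial\, \nabla_\theta L_{\text{train}}(x, y_{\text{train}})}{\partial y_{\text{train}}}\right)^{\top} \nabla_{\theta'} L_{\text{val}}(\theta')$$

The mixed second-derivative factor is the link that disappears when the update is applied by `optimizer.step()`, which does not record the update in the autograd graph.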

As you note, you can’t do this with backward() because pytorch
doesn’t backpropagate through the optimization step. In fact, pytorch
appears to be quite fastidious about preventing such backpropagation.

The only approach I can think of would be to use numerical differentiation
with respect to the training labels. Although plausible, such an approach
doesn’t really seem very attractive. It’s cumbersome, could become
expensive, and numerical differentiation is generally somewhat tricky.
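A minimal sketch of the numerical approach, assuming a hypothetical tiny `nn.Linear` model, random data, and one plain SGD step per evaluation. It uses central differences in float64 to keep rounding error small, and each label entry costs two full train-step-plus-validation passes, which illustrates why this gets expensive.

```python
import copy
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def soft_xent(logits, target):
    return -(target * F.log_softmax(logits, dim=1)).sum() / logits.shape[0]

def val_loss_after_step(model, y_train, x_train, x_val, y_val, lr=0.1):
    # Work on a copy so every evaluation starts from the same weights
    m = copy.deepcopy(model)
    opt = torch.optim.SGD(m.parameters(), lr=lr)
    opt.zero_grad()
    soft_xent(m(x_train), y_train).backward()
    opt.step()
    with torch.no_grad():
        return soft_xent(m(x_val), y_val).item()

def numerical_label_grad(model, y_train, x_train, x_val, y_val, eps=1e-3):
    grad = torch.zeros_like(y_train)
    for idx in range(y_train.numel()):
        y_plus = y_train.clone()
        y_minus = y_train.clone()
        y_plus.view(-1)[idx] += eps
        y_minus.view(-1)[idx] -= eps
        # Central difference: two full step-and-validate passes per label entry
        f_plus = val_loss_after_step(model, y_plus, x_train, x_val, y_val)
        f_minus = val_loss_after_step(model, y_minus, x_train, x_val, y_val)
        grad.view(-1)[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

# Hypothetical stand-ins for the real model and data, in float64
model = torch.nn.Linear(4, 3).double()
x_train = torch.randn(8, 4, dtype=torch.float64)
y_train = F.one_hot(torch.randint(0, 3, (8,)), num_classes=3).double()
x_val = torch.randn(5, 4, dtype=torch.float64)
y_val = F.one_hot(torch.randint(0, 3, (5,)), num_classes=3).double()

num_grad = numerical_label_grad(model, y_train, x_train, x_val, y_val)
print(num_grad.shape)  # torch.Size([8, 3])
```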

Good luck!

K. Frank