Torch.max() losing gradients

Hi, everyone!

I am writing a neural classifier with two output classes and a batch size of 5, so its output is a tensor of size (5, 2).

Also, I am using BCEWithLogitsLoss as the loss function.

As far as I understand, BCEWithLogitsLoss accepts a vector of integers (one for each element in the batch), while the output of my network is a one-hot vector of two elements.

To convert from one-hot to a scalar class index, I am using the max method and taking the indices, but when I do so I get the following error from the .backward() call:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
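
For reference, here is a minimal snippet that reproduces the error with dummy tensors of the same shapes as my setup (batch of 5, 2 classes):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

logits = torch.randn(5, 2, requires_grad=True)   # stand-in for the network output
targets = torch.randint(0, 2, (5,)).float()      # one target per batch element

# max(...).indices is an integer tensor detached from the graph,
# and casting it to float does not reattach it.
outputs = logits.max(1).indices.float()

loss = criterion(outputs, targets)
loss.backward()  # raises the RuntimeError above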

Please, is there a different way in PyTorch to convert from one-hot to a scalar class index?

Here is the relevant part of the code:

class ChoamNet(nn.Module):

    def __init__(self):
        super(ChoamNet, self).__init__()

        # Model Architecture.
        self.adapt_in = nn.Conv1d(in_channels=21, out_channels=256, kernel_size=1)
        self.act_in = nn.PReLU(256)
        
        self.block01_01 = ResNextBlock(256, 32, 256, 32)
        self.block01_02 = ResNextBlock(256, 32, 256, 32)
        self.block01_03 = ResNextBlock(256, 32, 256, 32)
        self.block01_04 = ResNextBlock(256, 32, 256, 32)
        
        self.adapt_01_02 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1)
        self.act_01_02 = nn.PReLU(512)
        
        self.block02_01 = ResNextBlock(512, 64, 512, 32)
        self.block02_02 = ResNextBlock(512, 64, 512, 32)
        self.block02_03 = ResNextBlock(512, 64, 512, 32)
        self.block02_04 = ResNextBlock(512, 64, 512, 32)
        
        self.adapt_02_03 = nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=1)
        self.act_02_03 = nn.PReLU(1024)
        
        self.block03_01 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_02 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_03 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_04 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_05 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_06 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_07 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_08 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_09 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_10 = ResNextBlock(1024, 128, 1024, 64)
        self.block03_11 = ResNextBlock(1024, 128, 1024, 64)
        
        self.fc1 = nn.Linear(14 * 1024, 2)
        
    def forward(self, x):

        x = self.adapt_in(x)
        x = self.act_in(x)
        
        x = self.block01_01(x)
        x = self.block01_02(x)
        x = self.block01_03(x)
        x = self.block01_04(x)
        
        x = self.adapt_01_02(x)
        x = self.act_01_02(x)

        x = self.block02_01(x)
        x = self.block02_02(x)
        x = self.block02_03(x)
        x = self.block02_04(x)
        
        x = self.adapt_02_03(x)
        x = self.act_02_03(x)

        x = self.block03_01(x)
        x = self.block03_02(x)
        x = self.block03_03(x)
        x = self.block03_04(x)
        x = self.block03_05(x)
        x = self.block03_06(x)
        x = self.block03_07(x)
        x = self.block03_08(x)
        x = self.block03_09(x)
        x = self.block03_10(x)
        x = self.block03_11(x)
        
        x = x.view(-1, 14 * 1024)
        
        x = F.softmax(self.fc1(x), 1)
        
        return x

def perform_training(model, trn_data_i, trn_data_o, tst_data_i, tst_data_o):
    
    model.apply(init_weights)
    
    #loss = nn.CrossEntropyLoss()
    loss = nn.BCEWithLogitsLoss()
    optim = opt.Adam(model.parameters(), lr=1e-3)
    
    tmp_trn_loss = []
    tmp_trn_acc = []
    
    tmp_tst_loss = []
    tmp_tst_acc = []
    
    tmp_duration = []
    
    tmp_epoch = []
    found = False
    fnum = 1
    file = r'D:\Project CHOAM\Results\Epoch 001.mdl'
    found_file = file
    while os.path.exists(file) and os.path.isfile(file):
        found_file = r'D:\Project CHOAM\Results\Epoch ' + "{:03d}".format(fnum) + '.mdl'
        fnum += 1
        file = r'D:\Project CHOAM\Results\Epoch ' + "{:03d}".format(fnum) + '.mdl'
        found = True
        
    if found:
        stats = read_csv(r'D:\Project CHOAM\Results\Progress.csv')
        
        tmp_trn_loss = list(stats['Training Loss'])
        tmp_trn_acc = list(stats['Training Precision'])
        
        tmp_tst_loss = list(stats['Test Loss'])
        tmp_tst_acc = list(stats['Test Precision'])
        
        tmp_duration = list(stats['Duration'])
        
        tmp_epoch = list(stats['Epoch'])
        
        model.load_state_dict(torch.load(found_file))
        model.train()
    
    for e in list(range(10000000))[fnum:]:
        y_true = []
        y_pred = []

        start = datetime.now()
        
        print('Performing training epoch number ' + str(e))
        
        stat_trn_loss = 0
        stat_trn_corr = 0
        stat_trn_tot = 0
        
        stat_tst_loss = 0
        stat_tst_corr = 0
        stat_tst_tot = 0
        
        # Run the training minibatches.
        for b in track(range(len(trn_data_i))):
            model.apply(check_nan)
            
            batch_i = trn_data_i[b]
            batch_o = trn_data_o[b]

            model.zero_grad()

            trn_out = model(batch_i)
            outputs = trn_out.max(1).indices.float()

            trn_loss = loss(outputs, batch_o.float())

            # Backpropagate errors.
            trn_loss.backward()
            optim.step()

            # Calculate training statistics.
            with torch.no_grad():
                stat_trn_loss += trn_loss / batch_o.size()[0]

                stat_trn_corr += float((trn_out.max(1).indices == batch_o).sum())
                stat_trn_tot += float(batch_o.size()[0])

                for p in trn_out.max(1).indices:
                    y_pred.append(p.item())

                for r in batch_o:
                    y_true.append(r.item())

        # Run the test minibatches.
        with torch.no_grad():
            model.eval()

            # Please, don't bother with this part... I will use the same strategy that works above.
            for b in track(range(len(tst_data_i))):
                batch_i = tst_data_i[b]
                batch_o = tst_data_o[b]

                tst_out = model(batch_i)

                #output = tst_out
                output = tst_out.max(2).indices
                labels = batch_o

                tst_loss = loss(output, labels)

                # Calculate test statistics.
                with torch.no_grad():
                    stat_tst_loss += tst_loss / batch_o.size()[0]

                    stat_tst_corr += float((tst_out.max(1).indices == batch_o).sum())
                    stat_tst_tot += float(batch_o.size()[0])

            model.train()

        # Calculate epoch duration.
        end = datetime.now()
        duration = end-start

        with torch.no_grad():
            tmp_trn_loss.append(stat_trn_loss.item())
            tmp_trn_acc.append((float(stat_trn_corr) / float(stat_trn_tot)))

            tmp_tst_loss.append(stat_tst_loss.item())
            tmp_tst_acc.append((float(stat_tst_corr) / float(stat_tst_tot)))
            print((stat_tst_corr, stat_tst_tot))

            dur = str(duration)
            dur = dur[0:dur.index('.')]
            tmp_duration.append(dur)

            tmp_epoch.append(e)

            df = DataFrame()

            df['Training Loss'] = tmp_trn_loss
            df['Training Precision'] = tmp_trn_acc

            df['Test Loss'] = tmp_tst_loss
            df['Test Precision'] = tmp_tst_acc

            df['Duration'] = tmp_duration

            df['Epoch'] = tmp_epoch

            df.to_csv(r'D:\Project CHOAM\Results\Progress.csv')

            torch.save(model.state_dict(), r'D:\Project CHOAM\Results\Epoch ' + "{:03d}".format(e) + '.mdl')

            clear_output()

            conf = confusion_matrix(y_true, y_pred)

            fig, ax = plt.subplots()
            heatmap(conf, annot=True, cbar=True, ax=ax)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.title('Training')
            plt.show()

Thanks in advance!

That’s not correct, as nn.BCEWithLogitsLoss expects raw logits as the model output (FloatTensors) and a target tensor with the same shape and type as the output containing values in [0, 1].

I’m a bit confused about your use case at the moment.
If you are dealing with a multi-label classification (each sample can belong to more than a single class), remove the softmax in your model and just pass the logits to the criterion without any max operations.
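
As a rough sketch of the multi-label case (with dummy tensors matching your batch size of 5 and 2 classes), the target simply has the same shape as the logits:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

logits = torch.randn(5, 2, requires_grad=True)  # raw model output, no softmax
targets = torch.randint(0, 2, (5, 2)).float()   # multi-hot targets, same shape as the logits

loss = criterion(logits, targets)
loss.backward()  # gradients flow, since no max/indices op was applied to the output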

On the other hand, if you are dealing with a multi-class classification (each sample belongs to one class only), still remove the softmax, use nn.CrossEntropyLoss as the criterion, and pass the targets as class indices (e.g. by using torch.argmax(target)).
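
The multi-class case could look roughly like this (again with dummy tensors; the one-hot targets are converted to class indices via argmax):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(5, 2, requires_grad=True)      # raw logits, no softmax
one_hot = torch.eye(2)[torch.randint(0, 2, (5,))]   # dummy one-hot targets, shape (5, 2)

targets = torch.argmax(one_hot, dim=1)              # class indices, shape (5,)
loss = criterion(logits, targets)
loss.backward()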

Thank you very much! That solved it!

The code I posted was a mess due to several failed attempts at fixing it, but torch.argmax was exactly the answer I needed.

Cheers!