【Need help】Loss and acc stop changing too early when using BCEWithLogitsLoss

I use BCEWithLogitsLoss for multi-label, multi-task learning, but after 2 epochs of training nothing changes any more…
Only in the 1st and 2nd epochs did the training acc increase; the validation acc and loss did not change even then.

From the 3rd epoch, everything will not change,

Before using BCEWithLogitsLoss I read this post, and I assume I have used it in the right way. Here are the details; is there anyone who can help me…?

(To illustrate the problem briefly, I omitted the calculation of BACC (balanced accuracy) in the code)

main func

def main():
    """Build the model, losses and optimizer, then run the train/validate loop.

    Reads ``../config.yaml``, parses the command-line arguments, selects the
    visible GPUs, and for every epoch adjusts the learning rate, trains on
    ``train_loader`` and validates on ``val_loader``.
    """
    best_acc = 0
    best_bacc = 0
    # Context manager so the config file handle is closed as soon as it is parsed.
    # NOTE(review): yaml.load executes whatever ``Loader`` permits; make sure
    # ``Loader`` is a safe loader if the config can come from an untrusted source.
    with open(r'../config.yaml') as f:
        conf = yaml.load(f, Loader=Loader)
    conf = set_data_dir(conf)
    args = parse()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    model = models.__dict__['Alexnet2fc']()

    attrWeights = attribute_weights(celeTrain=True, lfwTrain=True)
    # One criterion per attribute; reduction='none' keeps the per-element
    # losses so the caller decides how to aggregate them.
    criterions = [nn.BCEWithLogitsLoss(weight=attrWeights[i], reduction='none').cuda() for i in range(40)]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=conf['weight_decay'])

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    train_loader, val_loader, test_loader_celeba, test_loader_lfwa = data_loader(args, conf, )
    num_epochs = conf['epochs']
    taskNum = args.t
    for epoch in range(num_epochs):
        lr = adjust_learning_rate(conf, args, optimizer, epoch)
        train_loss, train_acc, train_bacc = train(model, optimizer, train_loader, criterions, taskNum, )
        val_loss, val_acc, val_bacc = validate(model, val_loader, criterions, taskNum, )

my model arch

class Alexnet2fc(nn.Module):
    """AlexNet-style feature extractor followed by ``taskNum`` output towers.

    NOTE: in this version every entry of ``self.towers`` is the *same*
    ``nn.Sequential`` object (``self.fc``), so all tasks share one set of
    head weights and produce identical outputs for a given input.
    """

    def __init__(self, taskNum=40, bitsPerAttr=1, ):
        super(Alexnet2fc, self).__init__()
        # Convolutional trunk: five conv + batch-norm + ReLU stages with three
        # max-pooling steps, followed by two linear layers.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, return_indices=False),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, return_indices=False),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, return_indices=False), 

            # nn.Linear acts on the last dimension of its input, so this
            # presumably expects the pooled feature map to have width 2 --
            # TODO confirm against the input image size.
            nn.Linear(2, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),

        )

        # Pools the last two dimensions down to 2x2 regardless of their size.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2)) 

        # Head: 1024 features -> 512 -> bitsPerAttr logits.
        fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, bitsPerAttr),
        )

        self.fc = fc
        # The list repeats the single ``self.fc`` instance taskNum times; no
        # new layers are created, so the towers are aliases of one module.
        self.towers = nn.ModuleList([self.fc for _ in range(taskNum)])

    def forward(self, x):
        """Return a list with one tower output per task for the batch ``x``."""
        x = self.features(x) 
        x = self.avgpool(x) 
        # Keep dim 0 and flatten the remaining dimensions into one.
        x = torch.flatten(x, 1) 
        out = [tower(x) for tower in self.towers] 
        return out

training code:

def train(model, optimizer, train_loader, criterions, taskNum, ):
    """Run one training epoch and return the average loss and accuracy.

    Args:
        model: network returning one output tensor per task.
        optimizer: optimizer updating ``model``'s parameters.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        criterions: per-task loss modules, forwarded to ``acc``.
        taskNum: number of tasks; each task covers ``40 // taskNum`` attributes.

    Returns:
        ``(loss_avg, acc_avg)``: the loss summed over the whole epoch divided
        by the number of samples seen, and the mean of the per-task accuracy
        percentages. Both are ``0.0`` when ``train_loader`` yields no batches.
    """
    model.train()
    taskAttrNum = 40 // taskNum  # num of attrs to pred in one task
    corrects = [0] * taskNum
    running_loss = 0.0
    samplesNum = 0
    loss_avg, acc_avg = 0.0, 0.0

    for inputs, labels in train_loader:
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        outputs = model(inputs)
        loss, correct = acc(criterions, outputs, labels, taskNum)
        loss_sum = sum(loss)
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

        # Accumulate a detached copy so the running total does not keep every
        # batch's autograd graph alive.
        running_loss = running_loss + loss_sum.detach()
        # Count the real batch size; the last batch may be smaller.
        samplesNum += labels.size(0)
        corrects = [corrects[t] + correct[t] for t in range(taskNum)]

    if samplesNum:
        # Named task_accs (not ``acc``): assigning to ``acc`` here would make
        # it a local name and break the call to the ``acc`` function above.
        task_accs = [100 * hits / (samplesNum * taskAttrNum) for hits in corrects]
        loss_avg = running_loss / samplesNum
        acc_avg = sum(task_accs) / len(task_accs)
    return loss_avg, acc_avg

loss and accuracy code

def acc(criterions, outputs, labels, taskNum, ):
    """Compute the summed loss and the number of correct predictions per task.

    Args:
        criterions: sequence of loss callables, indexed by task position.
        outputs: iterable of per-task model outputs (logits).
        labels: 2-D label tensor; task ``j`` uses columns
            ``[j * k, (j + 1) * k)`` with ``k = 40 // taskNum``.
        taskNum: number of tasks.

    Returns:
        ``(loss, correct)``: two lists with one entry per element of
        ``outputs`` -- the summed task loss and the count of matching
        predictions as returned by ``cal_acc``.
    """
    attrs_per_task = 40 // taskNum
    n_samples = labels.shape[0]
    losses = []
    hits = []

    for task_idx, logits in enumerate(outputs):
        first_col = attrs_per_task * task_idx
        target = labels[:, first_col:first_col + attrs_per_task].float()

        logits = logits.reshape((n_samples, -1))
        losses.append(criterions[task_idx](logits, target).sum())

        # A positive logit is counted as a predicted 1, anything else as 0.
        predicted = (logits > 0).to(torch.float)
        _, n_correct = cal_acc(predicted, target, taskNum)
        hits.append(n_correct)
    return losses, hits

def cal_acc(output, target, taskNum):
    """Return ``(accuracy_percent, num_correct)`` for one task.

    ``num_correct`` is the number of elements where ``output`` equals
    ``target``; the percentage is taken over
    ``target.size(0) * (40 // taskNum)`` elements.
    """
    attrs_per_task = 40 // taskNum
    with torch.no_grad():
        n_elements = target.size(0) * attrs_per_task
        matches = output.eq(target).sum().float().item()
        return 100 * matches / n_elements, matches

the despairing output:

Epoch: [1 | 90] LR: 0.000100
Training |################################| Batch (1311/1311) | Acc 97.4970 | BAcc 49.6422 | Loss 0.0004 | Total 0:09:03 | ETA 0:00:01
Validating |################################| Batch (165/165) | Acc 77.6407 | BAcc 38.5189 | Loss 0.0039 | Total 0:01:50 | ETA 0:00:01

Epoch: [2 | 90] LR: 0.000100
Training |################################| Batch (1311/1311) | Acc 97.6261 | BAcc 38.5228 | Loss 0.0004 | Total 0:07:07 | ETA 0:00:01
Validating |################################| Batch (165/165) | Acc 77.6407 | BAcc 38.5189 | Loss 0.0039 | Total 0:00:46 | ETA 0:00:01

Epoch: [3 | 90] LR: 0.000100
Training |################################| Batch (1311/1311) | Acc 97.6261 | BAcc 38.5228 | Loss 0.0004 | Total 0:07:09 | ETA 0:00:01
Validating |################################| Batch (165/165) | Acc 77.6407 | BAcc 38.5189 | Loss 0.0039 | Total 0:00:47 | ETA 0:00:01

Epoch: [4 | 90] LR: 0.000100
Training |################################| Batch (1311/1311) | Acc 97.6261 | BAcc 38.5228 | Loss 0.0004 | Total 0:07:08 | ETA 0:00:01
Validating |################################| Batch (165/165) | Acc 77.6407 | BAcc 38.5189 | Loss 0.0039 | Total 0:00:45 | ETA 0:00:01

Epoch: [5 | 90] LR: 0.000100
Training |################################| Batch (1311/1311) | Acc 97.6261 | BAcc 38.5228 | Loss 0.0004 | Total 0:07:10 | ETA 0:00:01
Validating |################################| Batch (165/165) | Acc 77.6407 | BAcc 38.5189 | Loss 0.0038 | Total 0:00:47 | ETA 0:00:01

It seems you are working on a multi-label classification with 40 labels (based on the number of loss functions you are creating).
Inside the model you are reusing the same nn.Sequential block (self.fc) for all tasks. Is this intended or would you like to initialize own layers for each task?

Hi @ptrblck, thank you so much for your help. Every word in your reply counts. Now my multi-task NN converges. What I’ve done was inspired by the valuable clues you provided: I changed the structure of the task layers so that each task gets its own layers, different from each other.

the new model arch. after the change:

class FC(nn.Module):
    """Fully connected head: 1024 inputs -> 512 hidden (ReLU) -> ``bitsPerAttr`` outputs."""

    def __init__(self, bitsPerAttr=1, ):
        super().__init__()
        hidden_width = 512
        stages = [
            nn.Linear(1024, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, bitsPerAttr),
        ]
        # Kept under the attribute name ``layers`` so parameter names stay
        # ``layers.0.*`` / ``layers.2.*``.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return the result."""
        return self.layers(x)

class Alexnet2fc(nn.Module):
    """AlexNet-style shared trunk followed by one independent ``FC`` head per task.

    Args:
        taskNum: number of task heads to create.
        bitsPerAttr: number of outputs produced by each head.
    """

    def __init__(self, taskNum=40, bitsPerAttr=1, ):
        super(Alexnet2fc, self).__init__()
        # Convolutional trunk: five conv + batch-norm + ReLU stages with three
        # max-pooling steps, followed by two linear layers.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ),

            # nn.Linear acts on the last dimension of its input, so this
            # presumably expects the pooled feature map to have width 2 --
            # TODO confirm against the input image size.
            nn.Linear(2, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )

        # Pools the last two dimensions down to 2x2 regardless of their size.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))

        # Build a fresh FC instance for every task. Repeating one module
        # instance in the list would make all towers share the same weights.
        # bitsPerAttr is forwarded so a non-default value is actually honoured.
        self.towers = nn.ModuleList([FC(bitsPerAttr) for _ in range(taskNum)])

    def forward(self, x):
        """Return a list with one tower output per task for the batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep dim 0 and flatten the remaining dimensions into one.
        x = torch.flatten(x, 1)
        out = [tower(x) for tower in self.towers]
        return out


You can’t imagine how excited I was when I saw the results become normal. I can’t thank you enough.

1 Like

Good to hear it’s working and the suggestion helped! :slight_smile:

1 Like