Inception-ResNet-v2 always gives the same prediction

I am trying to train the timm implementation of an Inception-ResNet-v2 model on the NIH ChestX-ray dataset. The issue is that after 6 hours of training accelerated across 2 GPUs, the model only learns to predict the same nonsense output for every image.

I am trying to do multi-label classification over 14 labels, one-hot encoded.
The images are originally grayscale, but I have repeated the channel three times so they fit the RGB model; the size is (3, 1024, 1024).
I have cleaned the dataset so that it contains around 50k images.

The training loop is below:

import timm
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from timm.data import resolve_data_config, create_transform
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed

# NIHChestXRays is my own Dataset wrapper (not shown here).

def training_function(mixed_precision="fp16", seed:int=42, batch_size:int=96):
    set_seed(seed)
    
    accelerator = Accelerator(mixed_precision=mixed_precision)
    model = timm.create_model('inception_resnet_v2', pretrained=False, num_classes=14)
    config = resolve_data_config({}, model=model)
    transform = create_transform(**config)


    train_dataset = NIHChestXRays("/kaggle/input/data", transform=transform)
    eval_dataset = NIHChestXRays("/kaggle/input/data", train=False, transform=transform)

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size*2, num_workers=2)

    for param in model.parameters(): 
        param.requires_grad=False
    for param in model.get_classifier().parameters():
        param.requires_grad=True
        
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None]
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None]
    
    mean = mean.to(accelerator.device)
    std = std.to(accelerator.device)
    
    optimizer = torch.optim.Adam(params=model.parameters(), lr = 3e-2/25)
    
    lr_scheduler = OneCycleLR(
        optimizer=optimizer, 
        max_lr=3e-2, 
        epochs=5, 
        steps_per_epoch=len(train_dataloader)
    )
  
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    total_loss = 0    
    for epoch in range(8):
        model.train()
        criterion = torch.nn.BCEWithLogitsLoss()
        for step, batch in enumerate(train_dataloader):
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = criterion(outputs, batch["label"])
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            total_loss += loss.detach().float()

        model.eval()
        accurate = 0
        num_elems = 0
       

        for step, batch in enumerate(eval_dataloader):
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
                
            predictions = torch.sigmoid(outputs)
            threshold = 0.5
            one_hot_output = (predictions > threshold).float()
            
            predictions = accelerator.gather(one_hot_output)
            targets = accelerator.gather(batch["label"])
            
            accurate_preds = predictions == targets
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        eval_metric = accurate.item() / num_elems
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
        
        accelerator.log(
                {
                    "accuracy": 100 * eval_metric,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                },
                step=step,
            )
       
    accelerator.wait_for_everyone() 
   
    model = accelerator.unwrap_model(model)
    
    accelerator.save(model, "model.pth")
    accelerator.save_state("state")
    accelerator.end_training()


args = ("fp16", 42, 16)
notebook_launcher(training_function, args, num_processes=2)

I’m unsure whether there is anything wrong or whether I should just add more epochs and hope for the best.
This is an example of what the model is predicting.

model = torch.load("/kaggle/input/inception-resnet-v2-trained/results/model.pth")
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

eval_dataset = NIHChestXRays("/kaggle/input/data", train=False, transform=transform)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=4, num_workers=2)

for batch in eval_dataloader:
    print(batch["label"])
    out = model(batch["image"].float().to("cuda:0"))
    break
    
print(out)
tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.]])
tensor([[-1.1811, -3.2016, -2.6149, -3.6426, -1.2084, -3.6994, -3.9469, -6.2628,
         -0.2163, -2.4023, -1.8594, -2.6814, -3.9607, -2.7347],
        [-1.1811, -3.2016, -2.6149, -3.6426, -1.2084, -3.6994, -3.9469, -6.2628,
         -0.2163, -2.4023, -1.8594, -2.6814, -3.9607, -2.7347],
        [-1.1811, -3.2016, -2.6149, -3.6426, -1.2084, -3.6994, -3.9469, -6.2628,
         -0.2163, -2.4023, -1.8594, -2.6814, -3.9607, -2.7347],
        [-1.1811, -3.2016, -2.6149, -3.6426, -1.2084, -3.6994, -3.9469, -6.2628,
         -0.2163, -2.4023, -1.8594, -2.6814, -3.9607, -2.7347]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

I’m a bit confused by

    for param in model.parameters(): 
        param.requires_grad=False
    for param in model.get_classifier().parameters():
        param.requires_grad=True

which would suggest that you are only training the final layer of the model. Note that since you are loading the model with pretrained=False (which randomly initializes the weights), this is similar to trying to fit the classifier layer on random features.

However, this should still not cause the model to be “stuck” on its predictions where each class is outputting the same value every time. I would check that you are not inadvertently wiping out your input data in the training or eval loop with the scaling operation (e.g., clamping it to all zeros or 255, etc.).
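
For example, a quick (hypothetical) sanity check just before the forward pass would show whether the normalized batch is degenerate:

inputs = (batch["image"] - mean) / std
# If these statistics are all ~0 (or constant across batches), the
# normalization is destroying the input signal.
print("min:", inputs.min().item(), "max:", inputs.max().item())
print("mean:", inputs.mean().item(), "std:", inputs.std().item())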

I have checked that

predictions = torch.sigmoid(outputs)
threshold = 0.5
one_hot_output = (predictions > threshold).float()

could theoretically make the prediction [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] if all the predicted values are less than the threshold (which is probably happening every time).
I have seen that in some cases no prediction is higher than 0.3, so the output is clamped to all zeros.

Should I lower the threshold or change that part of the eval loop so that it chooses the maximum value of the prediction tensor and makes it 1?
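
For reference, the argmax-based alternative I have in mind would be something like this (just a sketch, not something I have run yet):

predictions = torch.sigmoid(outputs)
# Mark only the single highest-probability class as 1.
top_class = predictions.argmax(dim=1)
one_hot_output = torch.zeros_like(predictions)
one_hot_output[torch.arange(predictions.shape[0], device=predictions.device), top_class] = 1.0
# Note: this forces exactly one positive label per image, which is not
# really multi-label any more.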

Also, regarding your comment about

    for param in model.parameters(): 
        param.requires_grad=False
    for param in model.get_classifier().parameters():
        param.requires_grad=True

I’m unsure what this code does since I copied it from an example provided by the Accelerate GitHub repo. Should I remove that part and set pretrained=True, or keep it and just set pretrained=True?

The part that is concerning me is that it looks like

    out = model(batch["image"].float().to("cuda:0"))

(which I assume to be the logits) produces exactly the same values every time. I don’t think this should occur unless something is wiping out the input to the model, or some intermediate layer of the model is wiping things out (such as a normalization layer with incorrect statistics). I think it would be useful to rule out the former, which is why I suggested inspecting the inputs actually being passed to the model and checking whether they are indeed different.
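
A quick check like this inside the eval loop (just a sketch using the tensors already in your loop) would be enough to rule that out:

inputs = (batch["image"] - mean) / std
# Are two different images in the batch actually arriving as different tensors?
print("first two inputs identical:", torch.allclose(inputs[0], inputs[1]))
print("per-image means:", inputs.flatten(1).mean(dim=1))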

In the long run, if you are intending to train the entire model, you should at least remove the param.requires_grad assignments, and experiment with whether starting from a pretrained model or training from scratch yields better accuracy.
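
If you do go the pretrained route, the change is roughly this (a sketch, using the same timm call you already have):

# ImageNet-pretrained weights, every parameter left trainable
# (i.e. drop the two requires_grad loops entirely).
model = timm.create_model('inception_resnet_v2', pretrained=True, num_classes=14)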

There are a few things that could be wrong and it’s a little hard to tell without more context.

  1. What does the training loss curve look like? Does the loss go down?
  2. What does the distribution of predicted probabilities look like? If your batches are super imbalanced with respect to the labels, then it’s possible the appropriate threshold is much lower than 0.5
  3. What is the distribution of your training labels? If they are mostly one class then you need to account for this by balancing the batches by class label (see the sketch after this list)
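
On point 3, here is a rough sketch of what I mean (pos_weight with BCEWithLogitsLoss is one common way to compensate; this assumes your dataset returns dicts with a "label" tensor, as in your training loop):

# Count how often each of the 14 labels is positive in the training set.
# (In practice you'd read the labels directly rather than decoding every image.)
labels = torch.stack([sample["label"] for sample in train_dataset])  # (N, 14)
pos_counts = labels.sum(dim=0)
neg_counts = labels.shape[0] - pos_counts
print("positives per class:", pos_counts)

# Up-weight the positive term of the loss for rare classes.
# (Move pos_weight to the same device as the logits before use.)
pos_weight = neg_counts / pos_counts.clamp(min=1)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)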

Also a small tip - you don’t need to repeat the greyscale image along all 3 channel dims. If you create the model with in_chans=1, you can just do image = image.unsqueeze(0), which takes your image from (1024, 1024) to (1, 1024, 1024), i.e. an image with one channel :)
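
For example (sketch, relying on timm’s in_chans argument):

# Build the model for single-channel input; timm adapts the stem conv.
model = timm.create_model('inception_resnet_v2', pretrained=False, num_classes=14, in_chans=1)
image = image.unsqueeze(0)  # (1024, 1024) -> (1, 1024, 1024)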

Anyways, let me know and I can try and help out