Loss oscillating and not decreasing

I am training a simple attention network on pre-extracted, stored ResNet features. Every gigapixel image is divided into approximately 20000 patches of size 256x256, and each patch is represented by a 1024-dimensional feature vector from a custom ResNet50. The data for each image therefore has shape [20000, 1024].
The training data loader loads one gigapixel image at a time, so the batch size is 1.
Reference: https://github.com/mahmoodlab/CLAM

class Attn_Net_Gated(nn.Module):
    """Gated attention scorer, following the CLAM reference linked above.

    Two parallel branches, ``attention_a`` and ``attention_b``, each start
    with ``nn.Linear(L, D)``. Their outputs are multiplied element-wise and
    mapped to ``n_classes`` scores per input row by ``attention_c``.

    NOTE(review): this is an abridged paste and does not parse as written:
    the two list literals are never closed, ``if dropout:`` has no body, and
    the ``super()`` call is under-indented. Whatever further layers each
    branch had were elided. TODO restore them from the real source before
    running.
    """
    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
      super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
        self.attention_b = [nn.Linear(L, D),
        if dropout:
        # Wrap each layer list so the branch can be called as one module.
        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes) # D-dim gated features -> n_classes scores

    def forward(self, x):
        """Return ``(A, x)``: per-row attention scores and the unchanged input.

        No softmax is applied here, so ``A`` holds raw scores.
        NOTE(review): assumes x is (n_patches, L), e.g. [20000, 1024] per
        slide as described in the post. TODO confirm.
        """
        a = self.attention_a(x)
        b = self.attention_b(x)
        # Gating: element-wise product of the two branch outputs.
        A = a.mul(b) 
        A = self.attention_c(A)  
        return A, x 

class MB(nn.Module):
    """Attention-based MIL model from the post (abridged paste).

    NOTE(review): as pasted this cannot run:
      * ``super().__init__()`` is never called, so assigning an
        ``nn.Module`` to ``self.attention_net`` would fail.
      * ``attention_net`` is built but never appended to ``fc``, so
        ``self.attention_net`` would be an empty ``nn.Sequential``.
      * ``relocate`` starts at column 0, which ends the class body early.
      * ``forward`` returns only when ``attention_only`` is True.
    ``k_sample``, ``instance_loss_fn``, ``label``, ``instance_eval`` and
    ``return_features`` are accepted but not used in the visible code.
    TODO restore the elided lines from the real source.
    """
    def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8, n_classes=2,
        instance_loss_fn=nn.CrossEntropyLoss(), subtyping=True):
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]} # layer widths per model size; size[0] (L) and size[2] (D) are used below
        size = self.size_dict[size_arg] 
        fc =[]
        if gate:
            attention_net = Attn_Net_Gated(L = size[0], D = size[2], dropout = dropout, n_classes = n_classes) 
        # NOTE(review): nothing is appended to `fc` above, so this Sequential is empty.
        self.attention_net = nn.Sequential(*fc)
        self.n_classes = n_classes
        self.subtyping = subtyping
def relocate(self):
        # NOTE(review): only computes a local `device`; no module is moved.
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
        """Score the bag ``h``; with ``attention_only`` return ``(A, h)``.

        ``A`` is transposed so that its first two axes are swapped.
        NOTE(review): presumably A goes from (n_patches, n_classes) to
        (n_classes, n_patches). TODO confirm.
        """
        device = h.device
        A, h = self.attention_net(h)         
        A = torch.transpose(A, 1, 0)  
        if attention_only:
            return A, h


def get_split_loader(split_dataset, training = False, testing = False, weighted = False):
	"""Return a DataLoader over ``split_dataset`` with batch size 1 (one slide per batch).

	Args:
		split_dataset: dataset for one split; must support ``len()``.
		training: if True, sample randomly. Use class-balanced weighted
			sampling when ``weighted`` is also True, uniform random
			sampling otherwise. If False, iterate in order
			(validation/test use).
		testing: if True, iterate sequentially over a random ~10 % subset
			of the split (at least one item when the split is non-empty).
			``training`` and ``weighted`` are then ignored.
		weighted: see ``training``.

	Returns:
		A ``DataLoader`` using ``collate_MIL_tr`` as its collate function.
	"""
	# Worker processes only when a GPU is available. Computed locally
	# instead of relying on a module-level `device` global.
	kwargs = {'num_workers': 4} if torch.cuda.is_available() else {}

	if testing:
		n_total = len(split_dataset)
		# Sample ids from range(n_total). The previous call passed the size
		# as arange's stop argument, which produced an empty array and
		# made np.random.choice raise.
		n_subset = min(n_total, max(1, int(n_total * 0.1)))
		ids = np.random.choice(np.arange(n_total), n_subset, replace = False)
		sampler = SubsetSequentialSampler(ids)
	elif not training:
		sampler = SequentialSampler(split_dataset)
	elif weighted:
		weights = make_weights_for_balanced_classes_split(split_dataset)
		sampler = WeightedRandomSampler(weights, len(weights))
	else:
		sampler = RandomSampler(split_dataset)

	return DataLoader(split_dataset, batch_size=1, sampler = sampler, collate_fn = collate_MIL_tr, **kwargs)
def get_optim(model, args):
	"""Create the optimizer named by ``args.opt`` for ``model``'s trainable parameters.

	Args:
		model: module whose parameters with ``requires_grad=True`` are optimized.
		args: namespace with ``opt`` (``"adam"`` or ``"sgd"``), ``lr`` (learning
			rate) and ``reg`` (passed as ``weight_decay``).

	Returns:
		``optim.Adam``, or ``optim.SGD`` with momentum 0.9.

	Raises:
		NotImplementedError: if ``args.opt`` is neither ``"adam"`` nor ``"sgd"``.
	"""
	# Frozen parameters are skipped so the optimizer never updates them.
	params = [p for p in model.parameters() if p.requires_grad]
	if args.opt == "adam":
		return optim.Adam(params, lr=args.lr, weight_decay=args.reg)
	if args.opt == 'sgd':
		return optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.reg)
	# Reached only for unsupported names. This raise must not sit inside the
	# 'sgd' branch, otherwise selecting SGD always fails.
	raise NotImplementedError("unsupported optimizer: {!r}".format(args.opt))
def initialize_weights(module):
	"""Re-initialise, in place, the ``nn.Linear`` and ``nn.BatchNorm1d`` layers in ``module``.

	``module.modules()`` yields ``module`` itself and every submodule.
	``nn.BatchNorm1d`` layers get weight 1 and bias 0.

	NOTE(review): the ``nn.Linear`` branch has no body in this paste, so the
	function does not parse as written. The Linear initialisation scheme was
	elided. TODO restore it from the real source.
	"""
	for m in module.modules():
		if isinstance(m, nn.Linear):
			# NOTE(review): body missing in the paste (see docstring).
		elif isinstance(m, nn.BatchNorm1d):
			nn.init.constant_(m.weight, 1)
			nn.init.constant_(m.bias, 0)


def train(datasets, cur, args):
    """Train ``MB`` on fold ``cur``; return ``(epoch_loss, val_loss)`` from the last epoch run.

    ``datasets`` is unpacked as ``(train_split, val_split, test_split)``. The
    split membership is saved to ``splits_{cur}.csv`` in ``args.results_dir``.

    NOTE(review): abridged paste that does not parse or run as written:
      * ``[''mmb']`` contains a stray quote (syntax error).
      * ``if stop:`` has no body (a ``break`` was presumably elided).
      * There is no ``else:`` before ``early_stopping = None``, so early
        stopping is always disabled. Likewise, with ``args.early_stopping``
        set, the checkpoint is loaded and then immediately saved back.
      * If the model-type condition is false, ``stop``, ``epoch_loss`` and
        ``val_loss`` are never bound.
      * ``model_dict``, ``instance_loss_fn``, ``writer`` and ``loss_fn`` are
        not defined in this function, and ``test_loader`` is unused.
      * No ``model.to(device)`` is visible, while ``train_loop`` moves each
        batch to CUDA when available. TODO confirm where the model is moved.
    """
    train_split, val_split, test_split = datasets
    save_splits(datasets, ['train', 'val', 'test'], os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    model = MB(**model_dict, instance_loss_fn=instance_loss_fn)     
    optimizer = get_optim(model, args)
    train_loader = get_split_loader(train_split, training=True, testing = args.testing, weighted = args.weighted_sample)
    val_loader = get_split_loader(val_split,  testing = args.testing)
    test_loader = get_split_loader(test_split, testing = args.testing)
    if args.early_stopping:
        early_stopping = EarlyStopping(patience = 20, stop_epoch=50, verbose = True)
        # NOTE(review): an `else:` appears to be missing; this always overrides the line above.
        early_stopping = None
    for epoch in range(args.max_epochs):
        if args.model_type in [''mmb'] and not args.no_inst_cluster:     
            epoch_loss = train_loop(epoch, model, train_loader, optimizer, args.n_classes, args.bag_weight, writer, loss_fn)
            stop, val_loss = validate(cur, epoch, model, val_loader, args.n_classes, early_stopping, writer, loss_fn, args.results_dir)     
        if stop:  
    if args.early_stopping:
        model.load_state_dict(torch.load(os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur))))
        # NOTE(review): presumably this save belonged to an elided `else:` branch.
        torch.save(model.state_dict(), os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))
    return epoch_loss, val_loss

def train_loop(epoch, model, loader, optimizer, n_classes, bag_weight, writer = None, loss_fn = None):
    """Run one epoch over ``loader``; print and return the mean per-slide loss.

    NOTE(review): as pasted, this loop never calls ``optimizer.zero_grad()``,
    ``l2_loss.backward()`` or ``optimizer.step()``. Without ``backward()`` no
    gradients are computed, and without ``step()`` no weight changes. This
    matches parameters whose ``.grad`` stays ``None``.
    ``n_classes``, ``bag_weight``, ``writer``, ``loss_fn``, ``h_feat``,
    ``coordinates`` and ``slide_id`` are unused here. ``coeff`` is not defined
    in this function. ``len(loader) == 0`` would raise ZeroDivisionError.
    """
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epoch_loss = 0.
    for batch_idx, (data, label, coordinates, slide_id) in enumerate(loader): 
        data, label = data.to(device), label.to(device)
        pred_val, h_feat = model(data, label = label, attention_only = True)
        # target_val is a fixed regression target; it needs no autograd graph,
        # but it must be a tensor on the same device as pred_val.
        target_val = #this is an array that I obtain from some interpolation and has the same shape of pred_val
        l2_loss = coeff * torch.nn.functional.mse_loss(pred_val.unsqueeze(0), target_val.unsqueeze(0))
        # NOTE(review): missing here: optimizer.zero_grad(); l2_loss.backward(); optimizer.step()
        # .item() copies the scalar to Python for logging only.
        epoch_loss +=  l2_loss.item()
    epoch_loss = epoch_loss / len(loader)
    print('Epoch: {}, train_loss: {:.4f} '.format(epoch, epoch_loss))
    return epoch_loss

The training loss oscillates within a fixed range of values and does not decrease, as shown below:

NOTE: I have repeated the experiment with learning rates from 1e-2 to 1e-6, weight decay from 1e-3 to 1e-6, both the Adam and SGD optimizers, and 50 to 200 epochs (with and without early stopping). The loss curves of all these runs look similar to the snapshot above.

Any help is appreciated.

Try to scale down the use case and overfit a small subset of the dataset. Since a single sample already produces many patches, you might first try to overfit your model to this single gigapixel image by tuning some hyperparameters.

Thank you for your suggestion. I tried overfitting the model on one gigapixel image with 30000 patches. On closer investigation, I noticed that the weights are not updated during training at all.
Using the reference code above, I checked the model parameters in every epoch as follows:

        # NOTE(review): `a` and `b` are cloned back-to-back from the same
        # parameter, so `torch.equal` always prints True. It cannot show
        # whether training changed the weights. To detect an update, clone
        # before `optimizer.step()` and compare with the parameter after it.
        a = list(model.parameters())[0].clone()
        b = list(model.parameters())[0].clone()
        print(torch.equal(a.data, b.data))
        # `param.grad` is None until a backward pass has accumulated a gradient into it.
        for name, param in model.named_parameters():
        	print(name, param.grad, param.requires_grad)


attention_net.0.attention_a.0.weight None True
attention_net.0.attention_a.0.bias None True
attention_net.0.attention_a.1.weight None True
attention_net.0.attention_a.1.bias None True
attention_net.0.attention_b.0.weight None True
attention_net.0.attention_b.0.bias None True
attention_net.0.attention_b.1.weight None True
attention_net.0.attention_b.1.bias None True
attention_net.0.attention_c.weight None True
attention_net.0.attention_c.bias None True

It seems that no gradient is computed (every `param.grad` is `None`), so the optimizer does not update the weights at all. Do you have any advice @ptrblck?

@ptrblck [UPDATE]: I found the problem. The computational graph was broken in the training loop, which prevented the gradients from propagating; I have fixed that. Sometimes an unforeseen, seemingly trivial bug is the most annoying one.

Great to hear you’ve isolated the issue!
Which operation was detaching the computation graph?

def train_loop(epoch, model, loader, optimizer, n_classes, bag_weight, writer = None, loss_fn = None):
    """Run one epoch over ``loader``; print and return the mean per-slide loss.

    NOTE(review): as pasted, this loop never calls ``optimizer.zero_grad()``,
    ``l2_loss.backward()`` or ``optimizer.step()``. Without ``backward()`` no
    gradients are computed, and without ``step()`` no weight changes. A
    target tensor with no autograd history does not by itself stop
    gradients from reaching ``pred_val``'s parameters.
    ``n_classes``, ``bag_weight``, ``writer``, ``loss_fn``, ``h_feat``,
    ``coordinates`` and ``slide_id`` are unused here. ``coeff`` is not defined
    in this function. ``len(loader) == 0`` would raise ZeroDivisionError.
    """
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epoch_loss = 0.
    for batch_idx, (data, label, coordinates, slide_id) in enumerate(loader): 
        data, label = data.to(device), label.to(device)
        pred_val, h_feat = model(data, label = label, attention_only = True)
        # target_val is a fixed regression target; it needs no autograd graph,
        # but it must be a tensor on the same device as pred_val.
        target_val = #this is an array that I obtain from some interpolation and has the same shape of pred_val
        l2_loss = coeff * torch.nn.functional.mse_loss(pred_val.unsqueeze(0), target_val.unsqueeze(0))
        # NOTE(review): missing here: optimizer.zero_grad(); l2_loss.backward(); optimizer.step()
        # .item() copies the scalar to Python for logging only.
        epoch_loss +=  l2_loss.item()
    epoch_loss = epoch_loss / len(loader)
    print('Epoch: {}, train_loss: {:.4f} '.format(epoch, epoch_loss))
    return epoch_loss

Computing target_val involved some non-differentiable operations in the interpolation. I used a workaround that avoids breaking the computation.
Thank you.

That’s strange, as the target value does not need to be attached to a computation graph and can be static: gradients only need to flow back through the model output (pred_val). Are you sure this was causing the issue?

It looks like it. The above fix worked for a single image. I will try it on a subset and get back to you.

[UPDATE] @ptrblck
For a single gigapixel image,

For 30 gigapixel images,

I will have to run more experiments with different learning rates and optimizers, since I can still see some oscillations. However, the model parameters are now updated, and the loss is decreasing.

@ptrblck, I have updated the comment above.