No convergence in a multi-task learning problem

I’m trying to train a model using multi-task learning. For this I defined some random data to first check if the loss is converging. Unfortunately the loss remains constant. For this I assume that I’m trying to output two tasks (assuming image data of shape (1,28,28) and process it a convolutional network (hard-parameter sharing approach). Further, I used pytorch-ignite.


import torch
from torch import nn
import torch.nn.functional as Fun
from torch.optim import SGD, Adam


from torchvision import transforms, datasets
from ignite.metrics import Precision, Recall
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.engine import Engine
from ignite.metrics import Accuracy
from ignite.engine.events import Events

from torch.utils.data import Dataset, DataLoader
import os
import sys



epochs = 20
device = 'cuda'



class Multitask(nn.Module):

  def __init__(self, n_class):
    super().__init__()

    self.n_class = n_class
    self.features = nn.Sequential(
      nn.Conv2d(1, 16, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(16, 32, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(32, 32, 3, padding=1, stride=2),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(32, 64, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(64, 64, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(64, 64, 3, padding=1, stride=2),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(64, 64, 3),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
      nn.Conv2d(64, 64, 1),
      nn.ReLU(inplace=True),
      nn.Dropout(0.2),
    )
    self.classifier = nn.Conv2d(64, self.n_class, 1)
    
    self.avg_pool = nn.AvgPool2d(kernel_size=(2,2))
    
    self.flatten = nn.Flatten()
    
    self.lin1 = nn.LazyLinear(self.n_class)
    self.lin2 = nn.LazyLinear(self.n_class)
    
  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)

    #avg_pool = nn.AvgPool2d((x.size(-2), x.size(-1)), stride=(x.size(-2), x.size(-1)))
    avg_pool = self.avg_pool(x)#.view(2)
    #x = avg_pool(x).view(16, self.n_class)  # shape=(batch_size, n_class)
    
    x = self.flatten(avg_pool)
    
    x1 = torch.nn.functional.softmax(self.lin1(x))
    x2 = torch.nn.functional.softmax(self.lin1(x))
    return [x1, x2]

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y1 = y[0]
        self.y2 = y[1]
        
        self.transform= transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((28,40)),
            transforms.ToTensor()
        ])
    
        self.transform = transforms.Compose([transforms.ToPILImage(),
                                             transforms.Resize((28,56)),
                                             transforms.ToTensor(),
                                              transforms.Normalize([0.485],
                                                                   [0.229]),
                                              ])
    
    
    def __len__(self):
        return len(self.y1)
    
    def __getitem__(self, idx, flip_rate=.5):
        if self.transform is not None:
            self.x[idx] =self.transform(self.x[idx])
        
        
        # if random.random() < flip_rate:
        #   self.x[idx] = torch.flip(self.x[idx], [0,1])
        #   self.y1[idx], self.y2[idx] = self.y2[idx], self.y1[idx]
          
        self.y1[idx] = int(self.y1[idx])
        self.y2[idx] = int(self.y2[idx])
        
        return (
            torch.tensor(self.x[idx]),
            torch.tensor(self.y1[idx], dtype=torch.int),
            torch.tensor(self.y2[idx], dtype=torch.int)
        )
    
    
model = Multitask(10).to(device)


x_train = torch.randn(25000,1,28,56)
x_test = torch.randn(500,1,28,56)
y1_train = torch.randint(0,10, (25000,))
y2_train = torch.randint(0,10, (25000,))
y1_test = torch.randint(0,10, (500,))
y2_test = torch.randint(0,10, (500,))

ds_train = MyDataset(x_train, [y1_train, y2_train])
train_loader = DataLoader(ds_train, batch_size=16, num_workers=0, shuffle=True)

optimizer = Adam([
        dict(params=model.parameters(), lr=1e-3),
        #dict(params=method.parameters(), lr=1e-1),
    ],
        # penalize_norm=penalize,
        # iter=engine.state.epoch-1
    )
criterion = nn.CrossEntropyLoss()


def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()


    batch = convert_tensor(batch, device=device, non_blocking=True)
    x1_ = batch[0].to(device)
    y1_ = batch[1].to(device)
    y2_ = batch[2].to(device)
  
    y_pred = model(torch.Tensor(x1_).to(device))

    y1_ = torch.Tensor(y1_).type(torch.LongTensor).to(device)
    y2_ = torch.Tensor(y2_).type(torch.LongTensor).to(device)
    y1_, y2_ = torch.nn.functional.one_hot(y1_, 10),  torch.nn.functional.one_hot(y2_, 10)
    
    losses1 = criterion(y_pred[0].float(), y1_.float())
    losses2 = criterion(y_pred[1].float(), y2_.float())
    loss = losses1 + losses2
    
    
    loss.backward()
    optimizer.step()
    extra_outputs = {}
    
    grads = [model.lin1.weight.grad, model.lin1.weight.grad]
    weights = [model.lin1.weight, model.lin2.weight]

    
    return {
        'loss1': losses1.detach().cpu().numpy(),
        'loss2': losses2.detach().cpu().numpy(),
        "total_loss": loss,
    }



trainer = Engine(train_step)



@trainer.on(Events.EPOCH_COMPLETED)
def log_training(engine):
    batch_loss = engine.state.output['total_loss']
    #lr = optim.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss} ")


ProgressBar(persist=True).attach(trainer, output_transform=lambda out: out)



loss = []
loss1, loss2 = [], []
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    epoch = engine.state.epoch
    iteration = engine.state.iteration
    loss1.append(engine.state.output['loss1'])
    loss2.append(engine.state.output['loss2'])
    loss = engine.state.output['total_loss']


trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_loss)
trainer.run(train_loader, max_epochs=epochs)

I would be helpful for advice.

nn.CrossEntropyLoss expects raw logits while you are applying F.softmax on the model outputs and are thus creating probabilities. Remove the softmax calls and check if this helps in training the model.

thank you very much ptrblck. That was it.

1 Like