Exploding gradients of the multilabel problem

I’m currently trying to train a model with ignite to get multiple labels l1 and ``l2````. My current problem is that I do not see my loss converging / exploding gradients. My first attempt was that I have a wrong interpretation of the cross entropy loss since they except raw logits instead of probabilities but that doesn’t seem to be the problem.
Here is my code:

import numpy as np
import random
import torch.nn.functional as Fun
from torch.optim import SGD, Adam
#from models.resnet import *
from methods.weight_methods import *
from torchvision import transforms, datasets
from ignite.metrics import Precision, Recall
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.engine import Engine
from ignite.metrics import Accuracy
from ignite.engine.events import Events
import os
import sys

from experiments.multimnist.dataset import *
from experiments.multimnist.models import *
#from methods.sgd import SGD

data_ = {'11': {'train': torch.randn(1000,100,1), 'test': torch.randn(600,100,1), 'val': torch.randn(200,100,1)},
       '22': {'train': torch.randn(1000,100,1), 'test': torch.randn(600,100,1), 'val': torch.randn(200,100,1)}}

aux_var_ = {
    '1': {'train': torch.randn(1000,1), 'test': torch.randn(600,1), 'val': torch.randn(200,1)},
    '2': {'train': torch.randn(1000,1), 'test': torch.randn(600,1), 'val': torch.randn(200,1)},
    '3': {'train': torch.randn(1000,1), 'test': torch.randn(600,1), 'val': torch.randn(200,1)}

labels_ = {
    'l1': {'train': torch.randint(0,3, (1000,)), 'test': torch.randint(0,3, (1000,)), 'val': torch.randint(0,3, (1000,))} ,
    'l2': {'train': torch.randint(0,3, (1000,)), 'test': torch.randint(0,3, (1000,)), 'val': torch.randint(0,3, (1000,))},

# hyperparameters
epochs = 100
penalize = False
batch_size = 1000
num = 8

class Model1(nn.Module):
    def __init__(self):
        self.conv1d1 = torch.nn.LazyLinear(32)
        self.conv1d2 = torch.nn.LazyLinear(16)
        self.relu1 = torch.nn.LeakyReLU()
        self.flatten = torch.nn.Flatten()
        self.fc1 = nn.LazyLinear(3)
        self.fc2 = nn.LazyLinear(5)
    def forward(self, x):      
        x = self.conv1d1(x)
        x = self.conv1d2(x)
        x = self.relu1(x)
        x = self.flatten(x)
        y1 = F.softmax(self.fc1(x))
        y2 = F.softmax(self.fc2(x))
        return [y1,y2]
    def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return chain(

    def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return chain(

    def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return self.flatten.parameters()

model = Model1().to('cpu')

class MyDataset(Dataset):
    def __init__(self, x, xaux, y, ds='train'):
        self.amp = x['11'][ds]
        self.phase = x['22'][ds]
        self.xaux  = xaux
        self.shelf_day = y['l1'][ds]
        self.sugar_content = y['l2'][ds]
        #self.aroma = y['aroma'][ds]
    def __len__(self):
        return len(self.shelf_day)
    def __getitem__(self, idx, flip_rate=.5):

        return (
            torch.tensor(self.shelf_day[idx], dtype=torch.int),
            torch.tensor(self.sugar_content[idx]-1, dtype=torch.int)

ds_train = MyDataset(data_, aux_var_, labels_, 'train')
train_loader = DataLoader(ds_train, batch_size=batch_size, num_workers=0, shuffle=True)

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

def accuracy_func(predictions, labels):
    classes = torch.Tensor([torch.argmax(pred, dim=0) for pred in predictions], device='cpu')
    return torch.mean((classes == labels).float())

import gc
def train_step_wo_smooth(engine, batch):
    device = 'cpu'
    # metric definition
    metric = Accuracy()

    batch = convert_tensor(batch, device=device, non_blocking=True)
    x1_ = batch[0].to(device)
    y1_ = batch[2].to(device)
    y2_ = batch[3].to(device)
    y_pred = model(torch.Tensor(x1_).to(torch.float32).to(device))    
    y1_ = torch.Tensor(y1_).type(torch.LongTensor).to(device)

    losses1 = F.cross_entropy(y_pred[0], y1_)
    extra_outputs = {}
    #grads = [model.lin1.weight.grad, model.lin1.weight.grad]
    #weights = [model.lin1.weight, model.lin2.weight]

    #acc1 = Accuracy(output_transform=lambda out: out['y_1']).attach(trainer, 'acc1')
    #acc2 = Accuracy(output_transform=lambda out: out['y_2']).attach(trainer, 'acc2')
    acc1 = accuracy_func(y_pred[0], y1_)
    acc2 = accuracy_func(y_pred[1], y1_)
    return {
        'y_1': (y_pred[0], y1_),
        'y_2': (y_pred[1], y2_),
        'loss': (losses1.detach().cpu().numpy()),
        'acc1': acc1,
        'acc2': acc2


if __name__ == '__main__':
    trainer = Engine(train_step_wo_smooth)    
    def log_training(engine):
        batch_loss = engine.state.output['loss']
        #lr = optim.param_groups[0]['lr']
        e = engine.state.epoch
        n = engine.state.max_epochs
        i = engine.state.iteration
        acc = engine.state.output['acc1']
        print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss} - accuracy: {acc}")

    loss1, loss2 = [], []
    def log_training_loss(engine):
        # print(engine.state.metrics)
        epoch = engine.state.epoch
        iteration = engine.state.iteration
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_loss)
    state = trainer.run(train_loader, max_epochs=epochs)

This is indeed wrong and I would recommend removing the softmax calls since internally F.log_softmax will be applied.
Besides that, have you tried to lower the learning rate as it’s a common root cause for exploding losses?