Loss.backward() not updating model parameters

Hi, I’m training microsoft/CLAP (GitHub: “Learning audio concepts from natural language supervision”), a CLIP-like model that trains an audio encoder and a text encoder jointly with a contrastive loss. However, the backward step does not seem to update the encoders’ parameters.
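The symmetric contrastive loss I’m trying to implement is, as I understand it from the paper (with $N$ the batch size, $\tau$ the temperature, and $\mathrm{sim}$ cosine similarity):

$$
\mathcal{L} = -\frac{1}{2N}\sum_{i=1}^{N}\left[\log\frac{\exp(\mathrm{sim}(a_i, t_i)/\tau)}{\sum_{j=1}^{N}\exp(\mathrm{sim}(a_i, t_j)/\tau)} + \log\frac{\exp(\mathrm{sim}(t_i, a_i)/\tau)}{\sum_{j=1}^{N}\exp(\mathrm{sim}(t_i, a_j)/\tau)}\right]
$$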

Here’s the code for the loss function:

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F

def contrastive_loss(E_a: Tensor, E_t: Tensor, temperature: float = 0.5, device="cpu") -> Tensor:
    sum_term = 0
    batch_size = len(E_a)
    N = range(batch_size)

    for i in N:
        # similarity of the matching (positive) audio/text pair
        pos = torch.exp(F.cosine_similarity(E_a[i], E_t[i], dim=-1) / temperature)
        a_t_neg = 0
        t_a_neg = 0

        for j in N:
            # denominators: audio i against every text j, and text i against every audio j
            a_t_neg = a_t_neg + torch.exp(F.cosine_similarity(E_a[i], E_t[j], dim=-1) / temperature)
            t_a_neg = t_a_neg + torch.exp(F.cosine_similarity(E_t[i], E_a[j], dim=-1) / temperature)

        a_t = torch.log(pos / a_t_neg)
        t_a = torch.log(pos / t_a_neg)
        sum_term = sum_term - (a_t + t_a)

    loss = 1 / (2 * batch_size) * sum_term
    loss = loss.to(device)  # Tensor.to() is not in-place, so the result has to be reassigned
    return loss

class ContrastiveLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor, temperature: float = 0.5) -> Tensor:
        # default temperature so the loss can be called as loss_function(input, target)
        return contrastive_loss(input, target, temperature)
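For reference, an equivalent vectorized version of the same loss (my own sketch, not taken from the CLAP repo; it should produce the same values without the Python loops):

import torch
from torch import Tensor
import torch.nn.functional as F

def contrastive_loss_vectorized(E_a: Tensor, E_t: Tensor, temperature: float = 0.5) -> Tensor:
    # Normalizing first makes the dot product equal to cosine similarity
    a = F.normalize(E_a, dim=-1)
    t = F.normalize(E_t, dim=-1)
    logits = a @ t.T / temperature  # (batch, batch) similarity matrix
    labels = torch.arange(len(E_a), device=logits.device)  # matching pairs sit on the diagonal
    # audio-to-text and text-to-audio cross-entropy terms, averaged over the batch
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2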

Training loop:

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from msclap.models.clap import AudioEncoder, TextEncoder, Projection, CLAP

audio_encoder = AudioEncoder(
    audioenc_name="HTSAT",
    d_in=768,
    d_out=1024,
    sample_rate=16000,
    window_size=1024,
    hop_size=320,
    mel_bins=64,
    fmin=50,
    fmax=8000,
    classes_num=527
)

audio_encoder.requires_grad_(True)

text_encoder = TextEncoder(
    text_model="gpt2",
    d_out=1024,
    transformer_embed_dim=768
)
text_encoder.requires_grad_(True)

print("=================")

audio_optimizer = torch.optim.Adam(audio_encoder.parameters(), lr=0.001)
text_optimizer = torch.optim.Adam(text_encoder.parameters(), lr=0.001)

loss_function = ContrastiveLoss()
# loss_function = nn.CrossEntropyLoss()

use_device = "cpu"
epochs = 1
batch_size = 5
limit = 5

audio_encoder.to(device=use_device)
text_encoder.to(device=use_device)

epoch_avg_losses = []

text_encoder.train()
audio_encoder.train()
data_loader = DataLoader(dataset, batch_size=batch_size)  # dataset: my audio/text Dataset (not shown here)

for epoch in range(epochs):

    current_losses = []
    progress = tqdm(data_loader, desc=f"Epoch: {epoch}")

    for audio_tensor, text_dict_raw in progress:
        text_input = {
            "input_ids": text_dict_raw["input_ids"].reshape(batch_size, -1), 
            "attention_mask": text_dict_raw["attention_mask"].reshape(batch_size, -1)}

        audio_optimizer.zero_grad()
        text_optimizer.zero_grad()

        audio_embedded, _ = audio_encoder(audio_tensor.reshape(batch_size, -1))
        text_embedded = text_encoder(text_input)

        loss_val = loss_function(audio_embedded, text_embedded)
        current_losses.append(loss_val.item())
        
        loss_val.backward(retain_graph=True)
        audio_optimizer.step()
        text_optimizer.step()        
        progress.set_postfix({"loss_val": loss_val.item()})
                         
    epoch_avg_losses.append(sum(current_losses) / len(current_losses))

I suspected that I did something wrong in my loss function, so I also tested with the default CrossEntropyLoss, but the two encoders’ parameters were not updated either.
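For anyone reproducing this: a quick way to verify whether a given weight moves is to snapshot it before the optimizer step and compare afterwards, e.g. this sketch around one training step (audio_encoder as above):

import torch

params = list(audio_encoder.parameters())
# snapshot two parameters before the optimizer step
before_first = params[0].detach().clone()
before_proj = params[-4].detach().clone()

# ... one forward / backward / optimizer.step() as in the loop above ...

print("param[0]  changed:", not torch.equal(before_first, params[0].detach()))
print("param[-4] changed:", not torch.equal(before_proj, params[-4].detach()))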

Is it possible that the problem lies in the models’ code itself? I’d really appreciate any pointers!

Could you print the .grad attribute of the parameter in question before and after the backward call?
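Something along these lines would do (just a sketch; adapt the variable names to your setup):

params = list(audio_encoder.parameters())

print("Grad param [0]:\n", params[0].grad)   # before backward
print("Grad param [-4]:\n", params[-4].grad)

loss_val.backward()

print("Grad param [0]:\n", params[0].grad)   # after backward
print("Grad param [-4]:\n", params[-4].grad)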

Here’s the output. I ran the training loop for 4 epochs on the same 4 samples.
The audio and text models have an encoder layer and a linear projection layer. Param 0 is the first param of the encoder layer, and param -4 is the first param of the projection layer.

The audio encoder is an HTSAT model, and the text encoder is GPT2.

Epoch: 0: 
====== audio, before
Grad param [0]: 
 None
Grad Param [-4]: 
 None
===== text, before
Grad param [0]: 
 None
Grad param [-4]: 
 None
====== audio, after backward
Grad param [0]: 
 None
Grad Param [-4]: 
 tensor([[ 0.0405, -0.0440, -0.0065,  ..., -0.0460,  0.0224, -0.0993],
        [-0.0389,  0.0331,  0.0211,  ...,  0.0342, -0.0255,  0.1370],
        [ 0.0455, -0.0086,  0.0186,  ...,  0.0344,  0.0122, -0.0025],
        ...,
        [-0.0632,  0.0321, -0.0272,  ..., -0.0262, -0.0011,  0.0038],
        [-0.0293,  0.0159, -0.0083,  ..., -0.0185,  0.0104,  0.0102],
        [ 0.0223, -0.0216, -0.0092,  ..., -0.0107, -0.0010, -0.0638]],
       device='cuda:0')
===== text, after backward
Grad param [0]: 
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
Grad param [-4]: 
 tensor([[-0.0060,  0.0006, -0.0095,  ..., -0.0043, -0.0034, -0.0021],
        [-0.0035,  0.0014, -0.0003,  ..., -0.0003,  0.0003, -0.0011],
        [ 0.0035, -0.0012,  0.0033,  ...,  0.0015,  0.0014,  0.0012],
        ...,
        [-0.0041, -0.0005, -0.0136,  ..., -0.0050, -0.0051, -0.0025],
        [-0.0036,  0.0016, -0.0035,  ..., -0.0025, -0.0026, -0.0001],
        [-0.0006, -0.0042, -0.0136,  ..., -0.0050, -0.0035, -0.0013]],
       device='cuda:0')

Epoch: 1
====== audio, before
Grad param [0]: 
 None
Grad Param [-4]: 
 None
===== text, before
Grad param [0]: 
 None
Grad param [-4]: 
 None
====== audio, after backward
Grad param [0]: 
 None
Grad Param [-4]: 
 tensor([[-0.0047,  0.0628, -0.0458,  ...,  0.0042, -0.0308,  0.0293],
        [-0.0257, -0.0219,  0.0169,  ..., -0.0322,  0.0028,  0.0342],
        [-0.0083, -0.0375,  0.0251,  ..., -0.0147,  0.0122,  0.0014],
        ...,
        [ 0.0198, -0.0528,  0.0408,  ...,  0.0127,  0.0268, -0.0604],
        [ 0.0051,  0.0356, -0.0302,  ...,  0.0120, -0.0180,  0.0075],
        [-0.0148,  0.0355, -0.0187,  ..., -0.0128, -0.0170,  0.0292]],
       device='cuda:0')
===== text, after backward
Grad param [0]: 
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
Grad param [-4]: 
 tensor([[-9.7118e-04,  1.3865e-03, -3.3962e-04,  ...,  1.4537e-04,
         -2.0846e-06, -4.0756e-04],
        [-6.3074e-04,  5.9609e-04, -3.2964e-04,  ...,  1.4668e-04,
         -1.7538e-04,  1.6203e-04],
        [ 1.9745e-04, -4.5461e-06,  2.4286e-04,  ...,  1.4766e-04,
         -2.2812e-04,  4.5129e-05],
        ...,
        [-1.5923e-04,  3.9380e-04,  5.8437e-06,  ...,  1.3212e-05,
          4.8233e-05, -2.6066e-04],
        [ 2.5247e-04, -1.2888e-04,  3.2683e-04,  ...,  1.6269e-04,
          1.1461e-05, -5.7793e-05],
        [ 6.7609e-05, -5.5542e-04, -2.9274e-04,  ..., -2.6170e-04,
          1.4302e-04,  2.7578e-04]], device='cuda:0')

Epoch: 2
====== audio, before
Grad param [0]: 
 None
Grad Param [-4]: 
 None
===== text, before
Grad param [0]: 
 None
Grad param [-4]: 
 None
====== audio, after backward
Grad param [0]: 
 None
Grad Param [-4]: 
 tensor([[ 0.0062, -0.0027,  0.0092,  ...,  0.0002,  0.0111,  0.0088],
        [ 0.0101, -0.0037,  0.0086,  ..., -0.0085,  0.0152,  0.0025],
        [-0.0105,  0.0039, -0.0144,  ...,  0.0017, -0.0190, -0.0110],
        ...,
        [-0.0053,  0.0020, -0.0068,  ...,  0.0013, -0.0091, -0.0056],
        [ 0.0073, -0.0035,  0.0103,  ..., -0.0007,  0.0127,  0.0094],
        [ 0.0238, -0.0086,  0.0224,  ..., -0.0158,  0.0359,  0.0132]],
       device='cuda:0')
===== text, after backward
Grad param [0]: 
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
Grad param [-4]: 
 tensor([[ 1.9479e-03, -1.2798e-03,  2.3455e-03,  ...,  9.3851e-04,
          6.2090e-04,  1.2598e-03],
        [-9.3708e-04, -1.5941e-04, -2.9056e-03,  ..., -9.7923e-04,
         -8.2801e-04, -9.7548e-04],
        [ 4.8224e-04,  1.5746e-04,  1.5780e-03,  ...,  5.2570e-04,
          4.3323e-04,  5.3249e-04],
        ...,
        [-1.1908e-04, -3.8168e-05, -6.1555e-04,  ..., -2.2111e-04,
         -1.9989e-04, -1.3724e-04],
        [-3.6051e-03,  1.6148e-03, -5.6189e-03,  ..., -2.2264e-03,
         -1.3088e-03, -2.6084e-03],
        [-8.9797e-04,  1.5050e-03,  1.6164e-03,  ...,  5.8679e-04,
          4.1812e-04, -1.9751e-04]], device='cuda:0')

Epoch: 3
====== audio, before
Grad param [0]: 
 None
Grad Param [-4]: 
 None
===== text, before
Grad param [0]: 
 None
Grad param [-4]: 
 None
====== audio, after backward
Grad param [0]: 
 None
Grad Param [-4]: 
 tensor([[-0.0042, -0.0046, -0.0006,  ...,  0.0042, -0.0002,  0.0065],
        [ 0.0189,  0.0079,  0.0082,  ..., -0.0107,  0.0106, -0.0067],
        [ 0.0067,  0.0083,  0.0016,  ..., -0.0072, -0.0019, -0.0171],
        ...,
        [ 0.0090,  0.0037,  0.0044,  ..., -0.0050,  0.0043, -0.0056],
        [ 0.0061, -0.0016,  0.0033,  ..., -0.0007,  0.0081,  0.0105],
        [-0.0109,  0.0059, -0.0067,  ..., -0.0005, -0.0174, -0.0262]],
       device='cuda:0')
===== text, after backward
Grad param [0]: 
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
Grad param [-4]: 
 tensor([[-9.9368e-04, -1.4957e-04, -1.7117e-03,  ..., -8.0239e-04,
         -2.9693e-04, -1.2470e-03],
        [ 3.3229e-03, -7.6360e-04,  3.1208e-03,  ...,  1.6842e-03,
          5.6865e-04,  2.8112e-03],
        [-2.5695e-03,  3.0420e-04, -3.1424e-03,  ..., -1.5772e-03,
         -5.2567e-04, -2.4934e-03],
        ...,
        [-4.9629e-04,  2.3430e-04, -2.6720e-04,  ..., -1.5403e-04,
         -4.8559e-05, -3.2742e-04],
        [ 3.0467e-03, -6.0293e-04,  3.1757e-03,  ...,  1.6730e-03,
          5.4095e-04,  2.6710e-03],
        [-4.2881e-03,  1.5830e-03, -2.7884e-03,  ..., -1.6225e-03,
         -5.6317e-04, -3.1121e-03]], device='cuda:0')


Losses: [2.7047977447509766, 4.15026330947876, 2.312796115875244, 5.260274887084961]


Thank you for the outputs! I assume your concern is the audio encoder’s None gradient for param[0]?
If so, could you post the model definition?

Thanks for the quick reply!

This is the encoder definition directly from their repo:

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel
from .audio import get_audio_encoder  # maps the "HTSAT" name to the audio backbone class

class Projection(nn.Module):
    def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds

class AudioEncoder(nn.Module):
    def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
            hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()

        audio_encoder = get_audio_encoder(audioenc_name)

        self.base = audio_encoder(
            sample_rate, window_size,
            hop_size, mel_bins, fmin, fmax,
            classes_num, d_in)

        self.projection = Projection(d_in, d_out)

    def forward(self, x):
        out_dict = self.base(x)
        audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
        projected_vec = self.projection(audio_features)
        # print("audio projected_vec requires_grad=", projected_vec.requires_grad)
        return projected_vec, audio_classification_output

class TextEncoder(nn.Module):
    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.text_model = text_model
        self.base = AutoModel.from_pretrained(text_model)

        if 'clip' in text_model:
            self.clip_text_projection = self.base.text_projection
            self.base = self.base.text_model
            if 'base' in text_model:
                transformer_embed_dim = 512
        
        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        if 'clip' in self.text_model:
            pooled_output = self.base(**x)[1] # get pooled output
            out = self.clip_text_projection(pooled_output)  # get CLS token output
        elif 'gpt' in self.text_model:
            batch_size = x['input_ids'].shape[0]
            hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)

            sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
            out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
        else:
            out = self.base(**x)[0]
            out = out[:, 0, :]  # get CLS token output
        
        projected_vec = self.projection(out)

        return projected_vec

The HTSAT model used to encode audio is defined in CLAP/msclap/models/htsat.py in the microsoft/CLAP repo on GitHub.

The text encoder simply loads the pretrained GPT2 model.

And sorry for not being clear. For more context, I initially ran the training loop for only 1 iteration (epochs = 1), with batch size = 5, on 100 samples. I did not see the batches’ losses decrease much; they all hovered around a certain point (which depends on the temperature parameter of the loss function). This led me to wonder whether I made a mistake in the loss function, or whether something in the model prevents the parameters from being updated.

Now, when I print out the grads of parameters 0 and -4, I see that param 0 has no grad, while param -4’s grad changes.

I also noticed that param 0 is not updated after .step(), while param -4 is.
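For reference, a quick way to list every parameter that ends up without a gradient (a sketch, run right after loss_val.backward(), with audio_encoder as in the training loop above):

for name, p in audio_encoder.named_parameters():
    if p.requires_grad and p.grad is None:
        print("no grad:", name)

Anything printed here does not contribute to the loss (or sits on a detached branch), and since Adam skips parameters whose .grad is None, those weights will never move on .step().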

Just an update: after adjusting the learning rate and batch size, and testing on a smaller dataset of 300 samples, I have seen the loss decrease after each epoch. I guess that means the code works as expected (unless you spot any errors in the code pieces I posted).