Not all weights train with combined loss functions

Hello!

I'm trying to do something which is possibly a bit silly, involving training multiple things in the same neural net by adding their loss functions together. Here's an example, in which a CNN learns image classification (on CIFAR-10) and, at the same time, a separate tensor, "noisyfriend", is trained to decrease its norm:

import torch
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
train_on_gpu = torch.cuda.is_available()

# Minimal CIFAR-10 loaders. (The data-loading code was trimmed from this
# post; the lines below are an assumed reconstruction matching the imports.)
transform = transforms.ToTensor()
train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform)
indices = list(range(len(train_data)))
np.random.shuffle(indices)
split = int(np.floor(0.2 * len(train_data)))
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, sampler=SubsetRandomSampler(indices[split:]))
valid_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, sampler=SubsetRandomSampler(indices[:split]))


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.noisyfriend = torch.normal(0, 1, size=(64, 64))
        self.noisyfriend.requires_grad_()
        if train_on_gpu:
            self.noisyfriend.cuda()
        
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.fc1 = nn.Linear(64*32*32, 256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(p=.2)
    
    def forward(self, x):
        conv_out = F.relu(self.conv1(x))
        conv_out = conv_out.view(-1, 64*32*32)
        y = self.dropout(conv_out)
        fc1_out = F.relu(self.fc1(y))
        z = self.dropout(fc1_out)
        fc2_out = self.fc2(z)
        noisyfriend = torch.linalg.norm(self.noisyfriend)
        return fc2_out, noisyfriend


model = Model()

if train_on_gpu:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=.001)
criterion = nn.CrossEntropyLoss()

n_epochs = 50
valid_loss_min = np.inf

for epoch in range(n_epochs):
    train_loss = 0.0
    valid_loss = 0.0
    
    # Model training
    model.train()
    for data, target in train_loader:
        if train_on_gpu:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)

        # combined objective: classification loss plus the norm of noisyfriend
        loss = criterion(output[0], target) + output[1]
        loss.backward()
        optimizer.step()

        train_loss += loss.item()*data.size(0)
        
    # Model eval
    model.eval()
    for data, target in valid_loader:
        if train_on_gpu:
            data, target = data.cuda(), target.cuda()
        output = model(data)[0]
        loss = criterion(output, target)
        valid_loss += loss.item()*data.size(0)
        
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
    
    print('Epoch: {} \tNoisy friend norm: {:.6f} \tCNN Loss: {:.6f}'.format(
        epoch, torch.linalg.norm(model.noisyfriend).item(), valid_loss))
    

In short, our loss function is the cross-entropy loss of the CNN components plus the (Frobenius) norm of a tensor which is otherwise unrelated to the CNN. What should happen is that the CNN trains and our tensor noisyfriend slowly becomes the zero tensor. The CNN trains fine (or about as well as you'd expect from something so bare-bones!), but noisyfriend stays resolutely at its exact initial values. Here's some output data:

Epoch: 0 	Noisy friend norm: 63.723152 	CNN Loss: 1.765747
Epoch: 1 	Noisy friend norm: 63.723152 	CNN Loss: 1.658773
Epoch: 2 	Noisy friend norm: 63.723152 	CNN Loss: 1.591563
Epoch: 3 	Noisy friend norm: 63.723152 	CNN Loss: 1.525593

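For reference on what I expected: since the gradient of the Frobenius norm ||W|| with respect to W is W / ||W||, plain SGD with lr = .001 should shrink the norm by about .001 every optimizer step. A standalone sanity check of that gradient (a minimal sketch, independent of the model above):

W = torch.normal(0, 1, size=(64, 64), requires_grad=True)
torch.linalg.norm(W).backward()
print(torch.allclose(W.grad, (W / torch.linalg.norm(W)).detach()))
> True
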
I have tried everything I could think of, which might make the code a bit inelegant: double-checking that gradients were on, and moving the norm computation into the forward function, just in case that would help. However, my changes have been in vain.
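For example, checking model.noisyfriend.grad right after loss.backward() inside the training loop (a quick diagnostic, not part of the script above) confirms that a gradient does reach the tensor, so the autograd graph seems fine; the values just never change:

print(model.noisyfriend.grad is not None)
> True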

I know that my example is quite silly, but I am incredibly confused. Any help is greatly appreciated!

Your noisy friend is not properly registered as a parameter and will thus not be passed to the optimizer via model.parameters():

print(dict(model.named_parameters()).keys())
> dict_keys(['conv1.weight', 'conv1.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

Use self.register_parameter or assign an nn.Parameter to self.noisyfriend and it will work:

self.noisyfriend = nn.Parameter(torch.normal(0, 1, size=(64, 64)))
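
Equivalently, via register_parameter:

self.register_parameter('noisyfriend', nn.Parameter(torch.normal(0, 1, size=(64, 64))))

After either change, noisyfriend shows up in the parameter dict and is thus picked up by optim.SGD(model.parameters(), lr=.001):

print(dict(model.named_parameters()).keys())
> dict_keys(['noisyfriend', 'conv1.weight', 'conv1.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])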

This also lets you remove the train_on_gpu condition from __init__, as self.noisyfriend will be moved to the device automatically with the rest of the model in the model.cuda() / model.to() call. (Incidentally, the self.noisyfriend.cuda() call in your version was a no-op anyway: Tensor.cuda() is not in-place, it returns a new tensor, and the result was discarded.)
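A quick check of the device handling (assuming a CUDA device is available):

model = Model()
model.cuda()
print(model.noisyfriend.device)
> cuda:0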
