Propagate gradients with two streams (two loss functions)

I’m new to PyTorch and am trying to implement an idea of mine with it.

I tried to create two backpropagation streams, one from an output in the middle of the network and one from the output at the end.

For example, take a network like

input -> conv1 -> conv2 -> conv3 -> fc1 -> fc2

Most classifiers use only the output of the fc2 layer.

But I want to use the outputs of fc1 and fc2 simultaneously: the output of fc1 feeds my custom loss function, and the output of fc2 feeds an ordinary cross-entropy loss.

I’ve implemented a network that returns the outputs of both fc1 and fc2:

import torch.nn as nn

class TestNet(nn.Module):

    def __init__(self, num_classes=10):
        super(TestNet, self).__init__()
        # Convolutional feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier1 = nn.Linear(256, 128)          # fc1
        self.classifier2 = nn.Linear(128, num_classes)  # fc2

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)    # flatten to (batch, features)
        x1 = self.classifier1(x)     # fc1 output, used by the custom loss
        x2 = self.classifier2(x1)    # fc2 output, used by cross entropy
        return x1, x2
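The flatten step followed by `nn.Linear(256, 128)` only works if the feature extractor produces exactly 256 values per sample. The sketch below checks this under the assumption that the inputs are 32x32 RGB images (the question does not state the input size, so this is a guess); other input sizes may flatten to a different width, in which case the first linear layer would raise a shape mismatch.

```python
import torch
import torch.nn as nn

# Same layer stack as TestNet.features in the question.
features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(192, 384, kernel_size=3, padding=1),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# Assumed input: a batch of two 32x32 RGB images.
x = torch.randn(2, 3, 32, 32)
flat = features(x).view(x.size(0), -1)
print(flat.shape)  # torch.Size([2, 256])
```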

In the training loop, my code uses x1 and x2 as shown below
(criterion_c is the cross-entropy loss from nn, and criterion_g is my custom loss function):

    x1, x2 = net(inputs)
    loss_c = criterion_c(x2, targets)
    loss_g = criterion_g(x1, vectors)

    loss = weight_g_loss * loss_g + loss_c
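For reference, a complete training step around this snippet might look like the minimal sketch below. Everything beyond the five lines above is an assumption: `TinyTwoOutputNet` is a hypothetical stand-in for TestNet, `nn.MSELoss` stands in for the custom `criterion_g`, and the optimizer, batch size, and value of `weight_g_loss` are arbitrary. A single `backward()` call on the summed loss sends gradients from both terms through the shared layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in that returns (x1, x2) like TestNet does.
class TinyTwoOutputNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.trunk = nn.Linear(32, 256)
        self.classifier1 = nn.Linear(256, 128)
        self.classifier2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.trunk(x))
        x1 = self.classifier1(x)
        x2 = self.classifier2(x1)
        return x1, x2

net = TinyTwoOutputNet()
criterion_c = nn.CrossEntropyLoss()
criterion_g = nn.MSELoss()   # assumption: placeholder for the custom loss
weight_g_loss = 0.5          # assumption: arbitrary weight
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Random data standing in for one mini-batch.
inputs = torch.randn(8, 32)
targets = torch.randint(0, 10, (8,))
vectors = torch.randn(8, 128)

optimizer.zero_grad()
x1, x2 = net(inputs)
loss_c = criterion_c(x2, targets)
loss_g = criterion_g(x1, vectors)
loss = weight_g_loss * loss_g + loss_c
loss.backward()              # one backward pass covers both loss terms
optimizer.step()

# classifier2 is reached only by loss_c, so a nonzero gradient here
# shows that the cross-entropy term is being backpropagated.
print(net.classifier2.weight.grad.abs().sum().item() > 0)
```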

But during training, there is a problem.
Only the gradients of loss_g seem to be backpropagated, not those of loss_c.
Only loss_g decreases.
When I removed loss_g as a test, loss_c did decrease.

So my question is:
Is there any way to make the two backpropagation streams work well together?
I suspect that having two paths is the problem…


Your code looks fine. Could you check the values of both losses to see whether loss_c is a lot smaller than loss_g?
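One way to run that check is to print both loss values together with the gradient norm each loss contributes to the layers they share. The sketch below is only an illustration under stated assumptions: `TwoHeadNet` is a hypothetical small stand-in for TestNet, `nn.MSELoss` replaces the custom `criterion_g`, and `grad_norm` is a helper written for this example, not a PyTorch function. It uses `torch.autograd.grad` so the `.grad` buffers are left untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in: a shared trunk followed by fc1 and fc2.
class TwoHeadNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.trunk = nn.Linear(32, 256)
        self.classifier1 = nn.Linear(256, 128)
        self.classifier2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.trunk(x))
        x1 = self.classifier1(x)
        x2 = self.classifier2(x1)
        return x1, x2

def grad_norm(loss, params):
    # Helper for this example: L2 norm of d(loss)/d(params).
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    norms = [g.norm() for g in grads if g is not None]
    return torch.stack(norms).norm().item() if norms else 0.0

net = TwoHeadNet()
criterion_c = nn.CrossEntropyLoss()
criterion_g = nn.MSELoss()   # assumption: placeholder for the custom loss
weight_g_loss = 1.0          # assumption: arbitrary weight

inputs = torch.randn(8, 32)
targets = torch.randint(0, 10, (8,))
vectors = torch.randn(8, 128)

x1, x2 = net(inputs)
loss_c = criterion_c(x2, targets)
loss_g = criterion_g(x1, vectors)

# Parameters that both losses flow through.
shared = list(net.trunk.parameters()) + list(net.classifier1.parameters())
norm_c = grad_norm(loss_c, shared)
norm_g = grad_norm(weight_g_loss * loss_g, shared)

print(f"loss_c={loss_c.item():.4f}  loss_g={loss_g.item():.4f}")
print(f"shared-layer grad norm: from loss_c={norm_c:.4e}, from weighted loss_g={norm_g:.4e}")
```

If one norm is orders of magnitude larger than the other, that loss dominates the updates to the shared layers, and adjusting `weight_g_loss` is a natural first thing to try.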