Updating the weights in a multioutput neural network

Hi,

I’m trying to create a multioutput / multihead feedforward neural network with some shared layers and two different heads.

Based on a condition, the forward step should split the samples so that one group goes into one head and the other group goes into the other head. I’m using indices to distinguish between the two groups.
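For reference, the same routing can be written with a boolean mask instead of `nonzero()` + `index_select`. This is a minimal sketch, assuming (as in the forward below) that the group flag sits in the last column of the input; the batch values are made up for illustration. Both forms are differentiable, so gradients flow back only to the selected rows.

```python
import torch

# Hypothetical batch: 4 features followed by a group flag in the last column.
x = torch.tensor([[0.1, 0.2, 0.3, 0.4, 1.0],
                  [0.5, 0.6, 0.7, 0.8, 0.0],
                  [0.9, 1.0, 1.1, 1.2, 1.0]])

flag = x[:, -1]
features = x[:, :-1]

# Approach from the post: row indices via nonzero(), then index_select.
idx_t = (flag == 1).nonzero().flatten()
x_t_index = torch.index_select(features, 0, idx_t)

# Equivalent boolean-mask indexing.
mask_t = flag == 1
x_t_mask = features[mask_t]
x_c_mask = features[~mask_t]

print(x_t_index.shape, x_c_mask.shape)  # torch.Size([2, 4]) torch.Size([1, 4])
```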

Example code (I reduced the code to a minimum. Let me know if you need more information):

def forward(self, x):
    indices_t = (x[:, -1] == 1).nonzero()
    indices_c = (x[:, -1] == 0).nonzero()
    x = x[:, :-1]

    pred_t = None
    pred_c = None

    # Shared layers
    z1 = self.fc1(x)
    a1 = F.elu(z1)
    z2 = self.fc2(a1)
    a2 = F.elu(z2)
    ...

    x_t = torch.index_select(a6, 0, indices_t.flatten())
    x_c = torch.index_select(a6, 0, indices_c.flatten())

    # Head One
    if x_t.shape[0] > 0:
        z1_t = self.fc_t_1(x_t)
        a1_t = torch.tanh(z1_t)
        ...
        # Keep raw logits: F.cross_entropy applies log_softmax internally,
        # so calling F.softmax here would apply softmax twice.
        pred_t = a4_t

    # Head Two
    if x_c.shape[0] > 0:
        z1_c = self.fc_c_1(x_c)
        a1_c = torch.tanh(z1_c)
        z2_c = self.fc_c_2(a1_c)
        a2_c = torch.tanh(z2_c)
        ...
        pred_c = a4_c  # raw logits as well, for F.cross_entropy

    return pred_t, pred_c

So far, the forward step works. The problem is the backward step / updating the weights.

I read here that I can simply add my two losses and compute the gradients. Unfortunately, the weights of both heads are updated, even if I process only samples from one group. For example, assume we have two groups, X and Y, and we process a single sample that belongs to group X. The idea is to use the loss from this sample to update the shared layers and the head belonging to group X, but not the head belonging to group Y. Unfortunately, in my case, both heads are updated.
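Summing the losses should not, by itself, push gradient into the wrong head: the gradient of `loss_t + loss_c` with respect to a head-t parameter is just the gradient of `loss_t`. The sketch below checks this on hypothetical stand-in layers (`shared`, `head_t`, `head_c` and the random data are assumptions, not the poster's model). If the gradients are clean, any unexpected change in a head's weights has to come from the optimizer step rather than from `backward()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the shared layers and the two heads.
shared = nn.Linear(4, 8)
head_t = nn.Linear(8, 2)
head_c = nn.Linear(8, 2)

x = torch.randn(6, 4)
group = torch.tensor([1, 0, 1, 1, 0, 0])  # hypothetical group flags
y = torch.randint(0, 2, (6,))

def compute_losses():
    h = F.elu(shared(x))
    mask_t = group == 1
    loss_t = F.cross_entropy(head_t(h[mask_t]), y[mask_t])
    loss_c = F.cross_entropy(head_c(h[~mask_t]), y[~mask_t])
    return loss_t, loss_c

# 1) Backward through loss_t only.
loss_t, _ = compute_losses()
loss_t.backward()
grad_t_alone = head_t.weight.grad.clone()
head_c_grad_after_t = head_c.weight.grad  # head_c took no part in loss_t

# 2) Reset all gradients to None, then backward through the summed loss.
for module in (shared, head_t, head_c):
    module.zero_grad(set_to_none=True)
loss_t, loss_c = compute_losses()
(loss_t + loss_c).backward()

print(head_c_grad_after_t)                               # None
print(torch.allclose(head_t.weight.grad, grad_t_alone))  # True
```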

Here is an excerpt of the backward step:

for x, y in train_dl:

    pred_t, pred_c = model(x)

    indices_t = (x[:, -1] == 1).nonzero()
    indices_c = (x[:, -1] == 0).nonzero()

    y_t = torch.index_select(y, 0, indices_t.flatten())
    y_c = torch.index_select(y, 0, indices_c.flatten())

    if pred_t is not None:
        loss1 = F.cross_entropy(pred_t, y_t)
        if pred_c is not None:
            loss2 = F.cross_entropy(pred_c, y_c)
            # Case 1: Both losses could be calculated
            loss = loss1 + loss2
        else:
            # Case 2: Only one loss could be calculated
            loss = loss1
    else:
        # Case 3: Only one loss could be calculated
        loss = F.cross_entropy(pred_c, y_c)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
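The three-case branching above can be written more compactly by collecting only the losses that exist and summing them, which needs a single `backward()` call. This is a sketch under assumptions: `TwoHeadNet` and `batch_loss` are hypothetical names, the model returns raw logits (not softmax output) and returns `None` for a group with no samples, mirroring the forward above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoHeadNet(nn.Module):
    # Hypothetical minimal model: shared layer + two heads, group flag in last column.
    def __init__(self, n_in=4, n_hidden=8, n_classes=2):
        super().__init__()
        self.shared = nn.Linear(n_in, n_hidden)
        self.head_t = nn.Linear(n_hidden, n_classes)
        self.head_c = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        mask_t = x[:, -1] == 1
        h = F.elu(self.shared(x[:, :-1]))
        # Raw logits, or None when a group has no samples in this batch.
        logits_t = self.head_t(h[mask_t]) if mask_t.any() else None
        logits_c = self.head_c(h[~mask_t]) if (~mask_t).any() else None
        return logits_t, logits_c

def batch_loss(model, x, y):
    mask_t = x[:, -1] == 1
    logits_t, logits_c = model(x)
    losses = []
    if logits_t is not None:
        losses.append(F.cross_entropy(logits_t, y[mask_t]))
    if logits_c is not None:
        losses.append(F.cross_entropy(logits_c, y[~mask_t]))
    return sum(losses)  # one scalar, so one backward() call is enough

torch.manual_seed(0)
model = TwoHeadNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# A batch that contains only group "t" samples (flag = 1).
x = torch.cat([torch.randn(5, 4), torch.ones(5, 1)], dim=1)
y = torch.randint(0, 2, (5,))

optimizer.zero_grad(set_to_none=True)
loss = batch_loss(model, x, y)
loss.backward()
print(model.head_c.weight.grad)  # None: head_c was not used for this batch
optimizer.step()
```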

I also tried setting requires_grad to False inside the for loop for the layers that should not be updated. Unfortunately, the weights in all layers are still updated on every step.
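One plausible reason the requires_grad approach had no effect, assuming Adam was used together with a zero-filling `zero_grad()` (the default in PyTorch versions before 2.0): setting `requires_grad` to False stops new gradients from being computed, but it does not remove a `.grad` tensor that already exists. Adam only skips parameters whose `.grad` is None, so a stale all-zero gradient still lets its running averages move the weights. A minimal sketch on a single hypothetical parameter `p`:

```python
import torch

p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.Adam([p], lr=0.1)

# One normal step builds up Adam's running averages.
p.grad = torch.ones(2)
opt.step()
opt.zero_grad(set_to_none=False)  # .grad is now a zero tensor, not None

# Freezing afterwards does not clear the stale zero gradient.
p.requires_grad_(False)
before = p.detach().clone()
opt.step()
print(torch.equal(before, p))  # False: Adam still moved p
```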

Thanks in advance.

The indexing of the input data and “selectively” calling backward works as expected in this code snippet:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.base = nn.Linear(1, 10)
        self.head1 = nn.Linear(10, 10)
        self.head2 = nn.Linear(10, 10)
        
    def forward(self, x):
        idx1 = (x == 0).nonzero()
        idx2 = (x == 1).nonzero()
        
        # base
        x = self.base(x)
        
        # heads
        x1 = torch.index_select(x, 0, idx1[:, 0])
        x2 = torch.index_select(x, 0, idx2[:, 0])
        
        x1 = self.head1(x1)
        x2 = self.head2(x2)
        
        return x1, x2

model = MyModel()
x = torch.randint(0, 2, (100, 1)).float()
out1, out2 = model(x)
print(out1.shape, out2.shape)

for name, param in model.named_parameters():
    print(name, param.grad)

out1.mean().backward(retain_graph=True)
for name, param in model.named_parameters():
    print(name, param.grad)

out2.mean().backward()
for name, param in model.named_parameters():
    print(name, param.grad)

As you can see, the gradients of the parameters of head2 will be None after the first backward call and will only be populated after out2.mean().backward() is called.

I guess the indexing in your code might not work as expected and might select all samples from the batch dimension. Could you verify that x_t and x_c contain separate sets of the input data and if so compare your approach to my example code snippet?
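A quick way to run that check, sketched for the flag tensor from the snippet above (shape (100, 1)); the same idea applies to the indices_t and indices_c in the original forward:

```python
import torch

x = torch.randint(0, 2, (100, 1)).float()  # flag tensor as in the snippet
idx1 = (x == 0).nonzero()[:, 0]
idx2 = (x == 1).nonzero()[:, 0]

# The two groups should share no row and together cover the whole batch.
overlap = set(idx1.tolist()) & set(idx2.tolist())
covered = len(idx1) + len(idx2) == x.shape[0]
print(len(overlap) == 0, covered)  # True True
```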


You are right, I can see that the gradients are None after the first backward call. However, in some cases where the batch only contains samples from one group, both heads are updated. In such a case, I expect only that group’s head to be updated.

I noticed that in the first rounds of training the behavior is as expected. For example, imagine the following batches:

  • bs1:= x=1, y=1
  • bs2:= x=1, y=0
  • bs3:= x=0, y=1

Here, in the first training iteration (using only bs1), only head1 is updated. The same applies to the second iteration. However, in the third iteration, both heads are updated, even though the gradients for prediction1 are zero.

I modified your example such that you can convince yourself:

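# Assumes `model` has been adapted to 3-feature inputs and that
# `optimizer_adam` is an Adam optimizer over model.parameters().
x = torch.randint(0, 2, (100, 3)).float()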
for t in x:
    out1, out2 = model(t.view(1, 3))

    # Store weights to compare them after updating them
    head1_weight_buffer = model.head1.weight.detach().clone()
    head2_weight_buffer = model.head2.weight.detach().clone()

    for name, param in model.named_parameters():
        print(name, param.grad)

    out1.mean().backward(retain_graph=True)
    for name, param in model.named_parameters():
        print(name, param.grad)

    out2.mean().backward()
    for name, param in model.named_parameters():
        print(name, param.grad)

    optimizer_adam.step()
    optimizer_adam.zero_grad()

    # Compare weights. We are expecting that just one head changed. 
    print(torch.all(torch.eq(head1_weight_buffer, model.head1.weight)))
    print(torch.all(torch.eq(head2_weight_buffer, model.head2.weight)))

I’m expecting one of the print statements to print True and the other to print False, but I never get one True and one False at the same time.

Regarding your index question: it looks like both your approach and mine work. I indeed got separate sets of input data. The only difference is idx.flatten() instead of idx[:, 0], right?
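The two are interchangeable only when the mask is 1D. `nonzero()` returns one column per input dimension, so for a 1D mask such as `x[:, -1] == 1` it has shape (N, 1) and `flatten()` equals `[:, 0]`. For a 2D mask such as `x == 0` with `x` of shape (N, 1), as in the snippet, it returns (row, col) pairs and `flatten()` would mix column indices into the row indices. A small sketch with made-up flags:

```python
import torch

flags_1d = torch.tensor([1., 0., 1., 1.])  # like x[:, -1] in the original forward
flags_2d = flags_1d.view(-1, 1)            # like x of shape (N, 1) in the snippet

nz_1d = (flags_1d == 1).nonzero()  # shape (3, 1): one column of row indices
nz_2d = (flags_2d == 1).nonzero()  # shape (3, 2): (row, col) pairs

print(nz_1d.flatten())  # tensor([0, 2, 3])
print(nz_1d[:, 0])      # tensor([0, 2, 3])
print(nz_2d.flatten())  # tensor([0, 0, 2, 0, 3, 0])  (column indices mixed in)
print(nz_2d[:, 0])      # tensor([0, 2, 3])
```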

I think I found a way to fix the issue. We need to distinguish between an empty prediction and a non-empty prediction. Further, we need to set the gradients to None instead of zero. Here is what I did:

x = torch.randint(0, 2, (100, 3)).float()
for t in x:
    out1, out2 = model(t.view(1, 3))

    # Store weights to compare them after updating them
    head1_weight_buffer = model.head1.weight.detach().clone()
    head2_weight_buffer = model.head2.weight.detach().clone()

    # Only call backward() on outputs that actually contain samples (non-empty predictions)
    if out1.shape[0] > 0:
        if out2.shape[0] > 0:
            out1.mean().backward(retain_graph=True)
            out2.mean().backward()
        else:
            out1.mean().backward()
    else:
        out2.mean().backward()

    optimizer_adam.step()
    # set_to_none=True resets .grad to None, so Adam skips parameters that received no gradient
    optimizer_adam.zero_grad(set_to_none=True)

    # Compare weights. We are expecting that just one head changed.
    print(torch.all(torch.eq(head1_weight_buffer, model.head1.weight)))
    print(torch.all(torch.eq(head2_weight_buffer, model.head2.weight)))

It solved my problem! Heads are now only updated when the batch contains samples from that head’s group. Can you confirm whether the code snippet above is correct? And why do we have to distinguish between empty and non-empty predictions?

I guess the issue you are seeing in the previous example is caused by using an optimizer with internal state, such as Adam.
Even if the gradients of certain parameters are zero, Adam might still update them if it has already built up running estimates for them in earlier steps.
By setting the gradients to None, you change this behavior: Adam skips the update for every parameter whose .grad is None.
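A minimal sketch of this behavior on a single hypothetical parameter `p`: after one real step, an all-zero gradient still moves the parameter (Adam's momentum term is nonzero), while a `None` gradient leaves it untouched.

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.Adam([p], lr=0.1)

# Step 1: a real, nonzero gradient builds up Adam's running averages.
p.grad = torch.ones(3)
opt.step()

# Step 2: the gradient is an all-zero tensor (not None).
before = p.detach().clone()
opt.zero_grad(set_to_none=False)  # .grad becomes zeros
opt.step()
changed_with_zero_grad = not torch.equal(before, p.detach())

# Step 3: the gradient is None, so Adam skips the parameter.
before = p.detach().clone()
opt.zero_grad(set_to_none=True)
opt.step()
changed_with_none_grad = not torch.equal(before, p.detach())

print(changed_with_zero_grad, changed_with_none_grad)  # True False
```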

That sounds reasonable! However, without the if/else statement, both heads are again updated, even though I’m only using an example from one of the groups. I think that is because of the retain_graph=True argument to backward(). In the case where out1 is empty, I still retain the graph, which (somehow?) leads to an update of the weights in both head1 and head2. Thank you very much for your help!
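A likely explanation for why the if/else is needed, offered as an assumption rather than something confirmed in this thread: calling backward() through a head that received zero samples does not leave that head’s .grad as None, it fills it with zeros. Adam then sees a (zero) gradient for that head and applies its momentum-driven update, which points at the zero-filled gradient rather than at retain_graph=True itself. A minimal sketch with a hypothetical head:

```python
import torch
import torch.nn as nn

head = nn.Linear(4, 3)
features = torch.randn(5, 4)

# Select no rows at all, as happens when a batch contains no sample of this group.
idx = torch.tensor([], dtype=torch.long)
out = head(torch.index_select(features, 0, idx))  # shape (0, 3)

loss = out.mean()  # the mean of an empty tensor is nan
loss.backward()

print(out.shape)                             # torch.Size([0, 3])
print(head.weight.grad is None)              # False
print(bool((head.weight.grad == 0).all()))   # True
```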