How to split backpropagation for two parameters predicted from one image

Hi,

I am trying to estimate two parameters of an object, such as its length and angle, from a single image using an EfficientNet. To do this, I feed the 1280-dimensional output of the EfficientNet into two dense layers with 320 output classes each: one dense layer predicts the angle and the other the length.

During training, I apply the cross-entropy (CE) loss to each parameter separately, sum the two losses, and divide by 2. Then I run backpropagation.

The training works fine, but I am wondering whether there is a smarter way to train each parameter while keeping this setup. Is it possible to backpropagate through each dense layer separately with its own loss and then combine the results via the chain rule to backpropagate through the EfficientNet block?
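For what it's worth, autograd already applies the chain rule this way: when both heads read the same shared features, the gradient reaching the backbone is the sum of the two heads' contributions, so backpropagating the averaged loss once gives the same backbone gradient as calling `backward()` on each half-loss separately (the latter just costs an extra pass and needs `retain_graph=True`). The sketch below checks this numerically; the tiny `backbone`, `len_head`, and `angle_head` modules are hypothetical stand-ins for the EfficientNet and the two dense layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a shared "backbone" and two classification heads
backbone = nn.Linear(8, 16)
len_head = nn.Linear(16, 4)
angle_head = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()

x = torch.randn(5, 8)
len_label = torch.randint(0, 4, (5,))
angle_label = torch.randint(0, 4, (5,))

def backbone_grad(separate: bool) -> torch.Tensor:
    """Return the backbone weight gradient for one of the two backward strategies."""
    for m in (backbone, len_head, angle_head):
        m.zero_grad()
    feats = backbone(x)  # shared features used by both heads
    len_loss = criterion(len_head(feats), len_label)
    angle_loss = criterion(angle_head(feats), angle_label)
    if separate:
        # Backward each half-loss on its own; retain_graph keeps the shared
        # backbone graph alive for the second call, and .grad accumulates.
        (len_loss / 2).backward(retain_graph=True)
        (angle_loss / 2).backward()
    else:
        # Single backward on the averaged loss, as in the original training code
        ((len_loss + angle_loss) / 2).backward()
    return backbone.weight.grad.clone()

g_sum = backbone_grad(separate=False)
g_sep = backbone_grad(separate=True)
print(torch.allclose(g_sum, g_sep, atol=1e-6))  # True
```

Because the gradients match, splitting the backward pass by itself does not change what the network learns; any change would have to come from how the two losses are weighted or defined.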

Here is how the two dense networks are connected to the output of the EfficientNet:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # ... (setup of eff_model, net_add and eff_b1_out_features elided)
        eff_model.classifier = nn.Sequential(
            nn.Linear(in_features=1280, out_features=self.eff_b1_out_features),
            nn.LeakyReLU()
        )
        # net_add contains the two dense layers (length and angle), each taking the 1280-dim vector as input
        self.model = nn.Sequential(eff_model, net_add)

    def forward(self, matrix):
        length, angle = self.model(matrix)  # avoid shadowing the built-in len
        return length, angle
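A self-contained version of the same two-head layout might look like the sketch below. It assumes the backbone returns a 1280-dim feature vector per image; the small `nn.Sequential` backbone here is a hypothetical stand-in (in practice one could use the EfficientNet with its classifier replaced), and `TwoHeadModel` is an illustrative name, not the original code.

```python
import torch
import torch.nn as nn

class TwoHeadModel(nn.Module):
    """Shared feature extractor followed by one classification head per target."""

    def __init__(self, backbone: nn.Module, feat_dim: int = 1280, n_bins: int = 320):
        super().__init__()
        self.backbone = backbone
        self.len_head = nn.Linear(feat_dim, n_bins)    # length logits
        self.angle_head = nn.Linear(feat_dim, n_bins)  # angle logits

    def forward(self, x: torch.Tensor):
        feats = self.backbone(x)  # (batch, feat_dim), shared by both heads
        return self.len_head(feats), self.angle_head(feats)

# Hypothetical stand-in for EfficientNet without its classifier:
# flattens a 3x32x32 image and maps it to 1280 features.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1280), nn.LeakyReLU())
model = TwoHeadModel(backbone)
len_out, angle_out = model(torch.randn(4, 3, 32, 32))
print(len_out.shape, angle_out.shape)  # torch.Size([4, 320]) torch.Size([4, 320])
```

Keeping the heads as named attributes (instead of wrapping everything in `nn.Sequential`) makes it easy to inspect or freeze each head separately.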

The training procedure follows these steps:

len_out, angle_out = self.model(matrix)
len_loss = criterion(len_out, len_label)
angle_loss = criterion(angle_out, angle_label)
total_loss = (len_loss + angle_loss) / 2
total_loss.backward()
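A complete training step around this snippet could look like the following sketch, assuming a standard optimizer loop. `TinyTwoHead`, `train_step`, and the weights `w_len`/`w_angle` are hypothetical; with both weights at 0.5 the loss is exactly `(len_loss + angle_loss) / 2`, and unequal weights are one common knob if one target trains more slowly than the other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyTwoHead(nn.Module):
    """Hypothetical small model: shared trunk plus two 320-way heads."""

    def __init__(self, in_dim=16, feat_dim=32, n_bins=320):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.LeakyReLU())
        self.len_head = nn.Linear(feat_dim, n_bins)
        self.angle_head = nn.Linear(feat_dim, n_bins)

    def forward(self, x):
        f = self.trunk(x)
        return self.len_head(f), self.angle_head(f)

model = TinyTwoHead()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(x, len_label, angle_label, w_len=0.5, w_angle=0.5):
    optimizer.zero_grad()  # clear gradients from the previous step
    len_out, angle_out = model(x)
    len_loss = criterion(len_out, len_label)
    angle_loss = criterion(angle_out, angle_label)
    # w_len = w_angle = 0.5 reproduces (len_loss + angle_loss) / 2
    total_loss = w_len * len_loss + w_angle * angle_loss
    total_loss.backward()  # one pass through both heads and the shared trunk
    optimizer.step()
    return total_loss.item()

x = torch.randn(8, 16)
len_label = torch.randint(0, 320, (8,))
angle_label = torch.randint(0, 320, (8,))
losses = [train_step(x, len_label, angle_label) for _ in range(50)]
print(losses[0], losses[-1])
```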

Any suggestions on how to improve the training process would be greatly appreciated.
Thank you!

This sounds as if you are treating a regression use case as a multi-class classification, where e.g. each possible angle corresponds to a class label.
If so, I would expect your model to overfit to a few angles, since I would also expect the dataset to be imbalanced. Did you make sure each angle appears with the same frequency in the data?
Did you try to output a single value each for the angle and the length and apply e.g. nn.MSELoss on them?
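A minimal sketch of that regression variant, assuming the backbone yields 1280 features per image: the head outputs two scalars (column 0 for length, column 1 for angle) and `nn.MSELoss` is applied to each. The stand-in backbone, the random inputs, and the assumption that targets are scaled to a similar range (here `[0, 1]`) are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the EfficientNet feature extractor (1280 features)
feat_dim = 1280
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, feat_dim), nn.LeakyReLU())
reg_head = nn.Linear(feat_dim, 2)  # one scalar per target: [length, angle]
model = nn.Sequential(backbone, reg_head)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.randn(4, 3, 32, 32)
# Float targets, assumed to be normalized to a similar scale
targets = torch.rand(4, 2)

optimizer.zero_grad()
pred = model(images)  # shape (4, 2)
len_loss = criterion(pred[:, 0], targets[:, 0])
angle_loss = criterion(pred[:, 1], targets[:, 1])
loss = (len_loss + angle_loss) / 2  # same averaging as in the original post
loss.backward()
optimizer.step()
```

Unlike the 320-bin classification, nearby values are no longer unrelated classes here: predicting an angle slightly off yields a small loss instead of the same penalty as a completely wrong bin.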
