How to forward dynamic layers properly?

Hi, I am working on a custom CNN module that needs to forward its input through several dynamically created layers (one per class/prototype pair, built in a loop). The module is trained with three loss functions, and their weighted sum is back-propagated.

The following is a snippet of my network module:

```python
class ProtoModule(nn.Module):
    def __init__(self, num_classes, num_proto):
        super(ProtoModule, self).__init__()
        ...
        self.proto_layers = nn.ModuleList()

        for c in range(self.num_classes):
            for p in range(self.num_proto):
                self.proto_layers.append(
                    nn.Sequential(
                        nn.Conv2d(chan, chan, kernel_size=(3, 3), stride=(1, 1)),
                        nn.ReLU(inplace=True)
                    )
                )

        self.fc = nn.Linear(self.num_classes * self.num_proto, self.num_classes)

    def forward(self, x):
        ...
        proto = Variable(torch.Tensor(), requires_grad=True)

        for proto_layer in self.proto_layers:
            res = proto_layer(x)
            res = F.max_pool2d(res, kernel_size=res.size()[2:])
            res = max_pool_channel(res, res.size()[1])
            proto = torch.cat((proto, res))

        proto = proto.reshape(curr_batch_size, self.num_classes * self.num_proto)
        x = self.fc(proto)

        return x, proto
```
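One detail worth checking in this forward pass: `torch.cat` concatenates along `dim=0` by default, which here is the batch dimension. The per-layer results are therefore stacked layer by layer, and the later `reshape(curr_batch_size, ...)` can put values from different samples into the same row. The small sketch below uses made-up values and assumes each pooled per-layer result has shape `(batch, 1)` (the exact shape depends on `max_pool_channel`, which is not shown). It compares the two concatenation choices:

```python
import torch

batch, num_layers = 2, 3

# Made-up per-layer outputs: layer k gives the value 10*k + sample_index,
# each with shape (batch, 1), so we can see where every value ends up.
outs = [torch.tensor([[10.0 * k + b] for b in range(batch)]) for k in range(num_layers)]

# dim=0 stacks the layers on top of each other (layer-major order);
# reshaping to (batch, num_layers) then mixes samples within a row.
wrong = torch.cat(outs, dim=0).reshape(batch, num_layers)

# dim=1 keeps one row per sample and one column per layer.
right = torch.cat(outs, dim=1)

print(wrong)
print(right)
```

In `wrong`, the first row contains values from both sample 0 and sample 1, so the final `fc` layer and the prototype losses would see scrambled features. In `right`, row `b` holds only sample `b`'s outputs.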

And here is how I currently back-propagate the loss:

```python
...
with torch.set_grad_enabled(phase == 'train'):
    output, proto_output = model(input)
    _, preds = torch.max(output, 1)

    loss1 = cross_entropy_loss(output, labels)
    loss2 = cluster_loss(proto_output, labels)
    loss3 = separation_loss(proto_output, labels)

    total_loss = loss1 + 0.5 * loss2 + 0.5 * loss3

    if phase == 'train':
        total_loss.backward()
        optimizer.step()
...
```
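Separately from the forward pass, the loop above calls `backward()` and `optimizer.step()` without an `optimizer.zero_grad()` call, unless that happens in the elided part. PyTorch accumulates gradients in each parameter's `.grad` across `backward()` calls, so without clearing them every step uses the sum of all previous gradients, which can prevent convergence. Below is a minimal sketch of a training step. It assumes a hypothetical `nn.Linear` model and a single cross-entropy loss standing in for the three losses:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)  # hypothetical stand-in for ProtoModule
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))


def train_step(inputs, labels):
    optimizer.zero_grad()                   # clear gradients left over from the previous step
    output = model(inputs)
    total_loss = criterion(output, labels)  # in the post: loss1 + 0.5 * loss2 + 0.5 * loss3
    total_loss.backward()                   # fills .grad for every parameter
    optimizer.step()                        # update using this batch's gradients only
    return total_loss.item()


losses = [train_step(inputs, labels) for _ in range(20)]
print(losses[0], losses[-1])
```

Calling `zero_grad()` at the start of the step (rather than after `step()`) also covers the validation phase, where `backward()` is never called.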

The problem is that my network won't converge, and I suspect the forward pass through the dynamic layers is the cause. Am I forwarding through them correctly?
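For comparison, here is a minimal sketch of the same module that collects the per-layer results in a Python list and concatenates them along `dim=1`, so each row of `proto` belongs to exactly one sample. It is an illustration, not the original code. The `chan` constructor argument is added here for illustration, and `max_pool_channel` (not shown in the post) is assumed to take the maximum over channels, replaced here by `amax(dim=1)`. The `Variable(...)` accumulator is dropped because tensors track gradients directly in current PyTorch versions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProtoModule(nn.Module):
    def __init__(self, num_classes, num_proto, chan):
        super().__init__()
        self.num_classes = num_classes
        self.num_proto = num_proto
        # One small conv block per (class, prototype) pair, registered via ModuleList
        self.proto_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(chan, chan, kernel_size=3, stride=1),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_classes * num_proto)
        ])
        self.fc = nn.Linear(num_classes * num_proto, num_classes)

    def forward(self, x):
        scores = []
        for proto_layer in self.proto_layers:
            res = proto_layer(x)                        # (B, chan, H', W')
            res = F.adaptive_max_pool2d(res, 1)         # global spatial max: (B, chan, 1, 1)
            # Assumed stand-in for max_pool_channel: max over channels -> (B, 1)
            res = res.amax(dim=1).flatten(start_dim=1)
            scores.append(res)
        proto = torch.cat(scores, dim=1)                # (B, num_classes * num_proto)
        return self.fc(proto), proto


model = ProtoModule(num_classes=4, num_proto=2, chan=8)
logits, proto = model(torch.randn(5, 8, 16, 16))
print(logits.shape, proto.shape)  # torch.Size([5, 4]) torch.Size([5, 8])
```

Because every conv block lives in the `nn.ModuleList`, its parameters are returned by `model.parameters()` and are updated by the optimizer like any other layer.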

Please advise. Thanks!