Hi, I am working on a custom CNN module that needs to forward its input through several dynamically created layers (built with nn.ModuleList). The module is trained with three loss functions that all need to be back-propagated.
The following is a snippet of my network module:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class ProtoModule(nn.Module):
    def __init__(self, num_classes, num_proto):
        super(ProtoModule, self).__init__()
        ...
        # One small conv + ReLU head per (class, prototype) pair.
        # chan is the channel count of the incoming features (set above).
        self.proto_layers = nn.ModuleList()
        for c in range(self.num_classes):
            for p in range(self.num_proto):
                self.proto_layers.append(
                    nn.Sequential(
                        nn.Conv2d(chan, chan, kernel_size=(3, 3), stride=(1, 1)),
                        nn.ReLU(inplace=True)
                    )
                )
        self.fc = nn.Linear(self.num_classes * self.num_proto, self.num_classes)

    def forward(self, x):
        ...
        proto = Variable(torch.Tensor(), requires_grad=True)
        for proto_layer in self.proto_layers:
            res = proto_layer(x)
            # Global spatial max pool: (N, C, H, W) -> (N, C, 1, 1)
            res = F.max_pool2d(res, kernel_size=res.size()[2:])
            # Max over all channels: (N, C, 1, 1) -> (N, 1, 1, 1)
            res = max_pool_channel(res, res.size()[1])
            proto = torch.cat((proto, res))
        proto = proto.reshape(curr_batch_size, self.num_classes * self.num_proto)
        x = self.fc(proto)
        return x, proto
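For completeness, max_pool_channel is a small helper of mine that max-pools across the channel dimension (with kernel_size equal to the channel count, it reduces (N, C, 1, 1) to (N, 1, 1, 1)). Roughly, it does something like this:

import torch
import torch.nn.functional as F

def max_pool_channel(x, kernel_size):
    # Pool over the channel axis by moving it to the last position:
    # (N, C, H, W) -> (N, C // kernel_size, H, W).
    n, c, h, w = x.size()
    y = x.reshape(n, c, h * w).permute(0, 2, 1)   # (N, H*W, C)
    y = F.max_pool1d(y, kernel_size)              # (N, H*W, C // kernel_size)
    return y.permute(0, 2, 1).reshape(n, -1, h, w)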
And here is how I currently back-propagate the loss:
...
with torch.set_grad_enabled(phase == 'train'):
    output, proto_output = model(input)
    _, preds = torch.max(output, 1)
    # Three loss terms: cross-entropy on the logits, plus two
    # prototype losses on the second output of the model.
    loss1 = cross_entropy_loss(output, labels)
    loss2 = cluster_loss(proto_output, labels)
    loss3 = separation_loss(proto_output, labels)
    total_loss = loss1 + 0.5 * loss2 + 0.5 * loss3
    if phase == 'train':
        total_loss.backward()
        optimizer.step()
...
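My understanding is that summing the three terms into total_loss and calling backward() once propagates gradients from all of them through the shared graph. Here is a toy example of the pattern I am relying on (made-up losses, just for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

# One shared layer, three loss terms on its output.
model = nn.Linear(4, 2)
x = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

out = model(x)
loss1 = F.cross_entropy(out, labels)
loss2 = out.pow(2).mean()        # stand-in for my cluster loss
loss3 = -out.abs().mean()        # stand-in for my separation loss

total_loss = loss1 + 0.5 * loss2 + 0.5 * loss3
total_loss.backward()            # one backward covers all three terms
print(model.weight.grad.shape)   # torch.Size([2, 4])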
The problem is that my network won't converge, and I suspect it has something to do with the forward pass through the dynamic layers, in particular the way I accumulate the per-layer results into proto with torch.cat. Did I do the forwarding correctly?
Please advise. Thanks!