Creating branches programmatically for a multilabel/multitask network

Hi

I’m trying to train a multilabel network to extract tags from images. There are about 250 tags. To increase accuracy, I thought of using a base network from the ResNet family and then adding separate branches on top of the ResNet output, one for each tag. I know how to do this manually; the problem is doing it for such a large number of tags. I would like to find a way to create the branches programmatically (in a loop). I tried using lists and dictionaries, but it didn’t work.

This is my code.


```python
import torch.nn as nn
import torchvision


class MultiTaskNet(nn.Module):
    # n_tags = number of tags in the dataset
    def __init__(self, n_tags):
        super(MultiTaskNet, self).__init__()
        self.base_net = torchvision.models.resnet18(pretrained=True)
        num_ftrs = self.base_net.fc.in_features
        self.base_net.fc = nn.Linear(num_ftrs, 1024)

        # head network attached to the output of base_net
        self.tag_net = nn.Sequential(
            nn.Linear(1024, 1024), nn.PReLU(),
            nn.Linear(1024, 512), nn.PReLU(),
            nn.Linear(512, 256), nn.PReLU(),
            nn.Linear(256, 1),
        )

        self.tag_net_list = []
        for i in range(0, n_tags):
            self.tag_net_list.append(self.tag_net)

    def forward(self, x):
        response = []
        x = self.base_net(x)
        for tag_net in self.tag_net_list:
            x2 = tag_net(x)
            response.append(x2)
        return response
```

I suspect the loss cannot work out how to update the branch weights in this structure.

Could you help me do this properly?

Thank you!

Hi,

Looking at your architecture, my first question is: is there a specific reason to branch tag_net for every tag? Your tag_net takes an input tensor of 1024 values and outputs a tensor of only 1 value, so instead of branching it for every tag, you can feed the output of resnet18 into a single tag_net whose last layer outputs as many values as you have tags. I’m guessing each image can contain multiple tags, so you should use nn.BCEWithLogitsLoss to calculate the loss and pass a multi-hot encoded vector as your target.

Take a look at this: https://gombru.github.io/2018/05/23/cross_entropy_loss/
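The linked post covers binary cross-entropy applied to each class independently. As a small check, assuming PyTorch's default reduction="mean", nn.BCEWithLogitsLoss should match the per-element formula -[y log σ(z) + (1 - y) log(1 - σ(z))] averaged over every tag of every image. The manual_bce_with_logits function is a hypothetical name used only for this comparison, and the logits and targets are made-up values.

```python
import torch
import torch.nn as nn


def manual_bce_with_logits(logits, targets):
    """Binary cross-entropy on sigmoid(logits), averaged over every element."""
    probs = torch.sigmoid(logits)
    per_element = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))
    return per_element.mean()


# Two images, three tags; targets are multi-hot (several 1s allowed per row).
logits = torch.tensor([[2.0, -1.0, 0.5], [-0.3, 1.5, -2.0]])
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

builtin = nn.BCEWithLogitsLoss()(logits, targets)  # default reduction="mean"
manual = manual_bce_with_logits(logits, targets)
print(builtin.item(), manual.item())
```

For training, the built-in loss is the better choice because it combines the sigmoid and the log in a numerically stable way; the manual version is only meant to make the formula concrete.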

Hope this helps.