I am trying to run a multitask CNN example on the Fashion-MNIST dataset.
The multitask neural network trains two tasks simultaneously.
The result is that only the first task's accuracy increases; the second task's accuracy stays almost the same.
My guess is that the optimizer optimizes the first task first, so the second task's accuracy cannot increase.
But I cannot confirm this, so I want to see the backward process of the multitask neural network.
Is there any way to see the backward sequence of a multitask neural network?
The two Fashion-MNIST tasks are both classification tasks:
Task 1: 10 classes

|Label|Description|Label|Description|
|-----|-----|-----|-----|
|0|T-shirt/top |5|Sandal|
|1|Trouser |6|Shirt|
|2|Pullover |7|Sneaker|
|3|Dress |8|Bag|
|4|Coat|9|Ankle boot|

Task 2: 3 classes

|Label|Original Labels|Description|
|-----|-----|-----|
|0|5, 7, 9|Shoes|
|1|3, 6, 8|For Women|
|2|0, 1, 2, 4|Other|
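For clarity, this is the small helper I use to derive the Task 2 labels from the original 10-class labels according to the table above (the name `to_task2_label` is just my own):

```python
# Task 2 groups the original Fashion-MNIST labels into 3 classes.
TASK2_MAP = {
    5: 0, 7: 0, 9: 0,         # Shoes
    3: 1, 6: 1, 8: 1,         # For Women
    0: 2, 1: 2, 2: 2, 4: 2,   # Other
}

def to_task2_label(label):
    """Map an original Fashion-MNIST label (0-9) to a Task 2 label (0-2)."""
    return TASK2_MAP[label]
```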
This is my network structure:

```python
import torch
import torch.nn as nn
import torch.nn.functional as f

class mtcnn(MtlFitModule):
    def __init__(self):
        super(mtcnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.drop = nn.Dropout(p=0.5)
        self.sub1 = nn.Linear(1600, 800)
        self.sub2 = nn.Linear(1600, 800)
        self.fc1 = nn.Linear(800, 10)
        self.fc2 = nn.Linear(800, 3)
        # freeze the task-1 branch
        self.sub1.weight.requires_grad = False
        self.sub1.bias.requires_grad = False
        self.fc1.weight.requires_grad = False
        self.fc1.bias.requires_grad = False

    def forward(self, x):
        x = f.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = f.relu(self.conv2(x))
        x = self.maxpool2(x)
        x = x.view(x.shape[0], -1)
        x = self.drop(x)
        sub2 = self.sub2(x)
        y2 = self.fc1(sub2)
        sub1 = self.sub1(x)
        y1 = self.fc1(sub1)
        return y1, y2
```
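Side note: to double-check which parameters remain trainable after the `requires_grad = False` lines above, I run a quick check like the following (a stand-alone sketch on a toy `nn.Linear`, not my real model):

```python
import torch.nn as nn

m = nn.Linear(4, 2)
m.weight.requires_grad = False  # mimic the freeze done in __init__ above
# collect the names of parameters that will still receive gradients
trainable = [name for name, p in m.named_parameters() if p.requires_grad]
print(trainable)  # only the bias is left trainable
```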
My loss and optimizer part (one training step):

```python
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fn = CrossEntropyLoss().to(cuda0)

# inside the training loop:
opt.zero_grad()  # clear the gradients from the previous step
loss1 = loss_fn(output[0], batch_y1)
loss2 = loss_fn(output[1], batch_y2)
loss = loss1 + loss2  # joint loss: sum of both task losses
loss.backward()
opt.step()
```
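To try to see the backward order myself, I experimented with module backward hooks on a toy two-head model (a sketch, not my actual mtcnn; `register_full_backward_hook` needs PyTorch 1.8+). Each hook records when backward reaches its module:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(8, 4)   # shared trunk
        self.head1 = nn.Linear(4, 10)   # task-1 head
        self.head2 = nn.Linear(4, 3)    # task-2 head

    def forward(self, x):
        h = torch.relu(self.shared(x))
        return self.head1(h), self.head2(h)

model = Toy()
order = []

def make_hook(name):
    def hook(module, grad_input, grad_output):
        order.append(name)  # record when backward reaches this module
    return hook

for name, mod in model.named_modules():
    if name:  # skip the root module itself
        mod.register_full_backward_hook(make_hook(name))

# requires_grad on the input so the hook on the shared layer also fires
x = torch.randn(2, 8, requires_grad=True)
y1, y2 = model(x)
loss = y1.sum() + y2.sum()
loss.backward()
print(order)  # both heads come first, the shared layer comes last
```

On the real model I would also print `p.grad` (or its norm) for every named parameter after `loss.backward()` to see which parameters actually receive gradients.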