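import torch.nn as nn

# One cross-entropy criterion per output head
issue_loss_fn = nn.CrossEntropyLoss()
product_loss_fn = nn.CrossEntropyLoss()
issue_loss = issue_loss_fn(issue_logits, issue_y)
# Score the product head against the product labels, not issue_y (product_y is an assumed name)
product_loss = product_loss_fn(pro_logits, product_y)
# Sum both losses so a single backward pass updates the shared layers and both heads
loss = issue_loss + product_loss
loss.backward()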
Is this the correct way to do it, or am I doing something wrong?
When the result for one label improves, the result for the other degrades:
epoch: 0 | loss: 1.705373126316846 | issue acc: 0.45827687443541104 | product acc: 0.5626646717856068 | time elapsed: 143.8780710697174
------------------------------------------------------------------------------------------
epoch: 0 | loss: 1.223887874321504 | issue acc: 0.630754662004662 | product acc: 0.7407852564102565 | time elapsed: 152.7476155757904
------------------------------------------------------------------------------------------
epoch: 1 | loss: 1.1055800391406547 | issue acc: 0.6752625338753387 | product acc: 0.6483316395663957 | time elapsed: 146.28524684906006
------------------------------------------------------------------------------------------
epoch: 1 | loss: 1.1996180794455789 | issue acc: 0.663224796037296 | product acc: 0.6807255244755245 | time elapsed: 155.0542027950287
------------------------------------------------------------------------------------------
epoch: 2 | loss: 0.8235846665816579 | issue acc: 0.7612635501355014 | product acc: 0.6567487202649804 | time elapsed: 146.46119809150696
------------------------------------------------------------------------------------------
epoch: 2 | loss: 1.1156758611852473 | issue acc: 0.7048186188811189 | product acc: 0.6636800699300699 | time elapsed: 155.3083529472351
------------------------------------------------------------------------------------------
epoch: 3 | loss: 0.6182250726998337 | issue acc: 0.8257772508280637 | product acc: 0.6497760463715748 | time elapsed: 146.60942959785461
------------------------------------------------------------------------------------------
epoch: 3 | loss: 1.4946378902955488 | issue acc: 0.700247668997669 | product acc: 0.638567162004662 | time elapsed: 155.38528728485107
------------------------------------------------------------------------------------------
Any suggestions?
Note: the second line of output for each epoch is for the validation set.
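
For reference, here is a minimal, self-contained sketch of the kind of two-head setup described above. It is not my actual model: `TwoHeadClassifier`, `multitask_loss`, `product_y`, `product_weight`, and all sizes are hypothetical names and values chosen only for illustration. It assumes each head is scored against its own label tensor, and it returns the per-task losses separately so each one can be logged on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TwoHeadClassifier(nn.Module):
    """Hypothetical stand-in: a shared encoder with one linear head per label."""

    def __init__(self, in_dim=16, hidden=32, n_issue=5, n_product=3):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.issue_head = nn.Linear(hidden, n_issue)
        self.product_head = nn.Linear(hidden, n_product)

    def forward(self, x):
        h = self.encoder(x)
        return self.issue_head(h), self.product_head(h)


def multitask_loss(issue_logits, pro_logits, issue_y, product_y, product_weight=1.0):
    # Each head is compared with its own targets.
    issue_loss = F.cross_entropy(issue_logits, issue_y)
    product_loss = F.cross_entropy(pro_logits, product_y)
    # product_weight is an assumed knob for rebalancing the two tasks.
    total = issue_loss + product_weight * product_loss
    return total, issue_loss, product_loss


model = TwoHeadClassifier()
x = torch.randn(8, 16)                  # dummy batch of 8 examples
issue_y = torch.randint(0, 5, (8,))     # dummy issue labels
product_y = torch.randint(0, 3, (8,))   # dummy product labels

issue_logits, pro_logits = model(x)
loss, issue_loss, product_loss = multitask_loss(
    issue_logits, pro_logits, issue_y, product_y
)
model.zero_grad()
loss.backward()  # gradients reach the shared encoder and both heads
```

Logging `issue_loss` and `product_loss` separately, rather than only their sum, would make it easier to see which task drives the rising validation loss.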