import torch
import torch.nn as nn
import torch.nn.functional as F

class network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        self.linear1 = nn.Linear(in_features=40, out_features=50)
        self.bn1 = nn.BatchNorm1d(num_features=50)
        self.linear2 = nn.Linear(in_features=50, out_features=25)
        self.bn2 = nn.BatchNorm1d(num_features=25)
        self.linear3 = nn.Linear(in_features=25, out_features=1)

    def forward(self, input):  # input: (batch_size, 40); each sample is a 1D feature vector
        out_list = []
        y = F.relu(self.bn1(self.linear1(input)))
        out_list.append(y)  # output of the first block
        y = F.relu(self.bn2(self.linear2(y)))
        out_list.append(y)  # output of the second block
        y = self.linear3(y)
        out_list.append(y)  # final output
        return out_list

model = network()
x = torch.randn(10, 40)
output_list = model(x)
Let’s say I have a target for each of the outputs in output_list.

Each output will then have a loss associated with it. Can I run loss.backward() for each one without updating variables that didn’t contribute to that loss? For example, the loss of the first output in the list shouldn’t change the weights of the last linear layer.

If I sum all the losses and run total_loss.backward(), will the loss of the first output affect the update of the last linear layer? I want total_loss.backward() to leave variables that didn’t contribute to a given loss untouched.
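To make the question concrete, here is a rough sketch of the two loss setups I have in mind. The targets and the MSELoss criterion are just placeholders for illustration, not my real training code:

# Hypothetical targets, one per entry of output_list (shapes chosen to match the outputs above).
target1 = torch.randn(10, 50)
target2 = torch.randn(10, 25)
target3 = torch.randn(10, 1)

criterion = nn.MSELoss()  # placeholder loss function

# Per-output losses.
out1, out2, out3 = model(x)
loss1 = criterion(out1, target1)
loss2 = criterion(out2, target2)
loss3 = criterion(out3, target3)

# Variant A: one backward call per loss.
# retain_graph=True is needed on all but the last call because the losses share one graph.
loss1.backward(retain_graph=True)
loss2.backward(retain_graph=True)
loss3.backward()

# Variant B: a single backward call on the summed loss (gradients zeroed and outputs recomputed first).
model.zero_grad()
out1, out2, out3 = model(x)
total_loss = criterion(out1, target1) + criterion(out2, target2) + criterion(out3, target3)
total_loss.backward()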