I know many have answered this issue but my problem is slightly different which I could not find an answer for.
Basically, I am creating my own custom loss (cost) function, however, it uses the output of the MLP as a means to calculate the actual loss in the system - the MLP output is not directly transferred in the loss function. I have coded a very simplified example of my problem for your understanding.
My intuition is that my current implementation is breaking the computational graph and thus the optimizer is not updating the weights. So my question is how could this be modified in order to make sure the computational graph does not break down and the weights update.
Please let me know if you need more explanation(s) of my problem. Thank you in advance!
Below is the sample code:
import torch.nn as nn
import torch.nn.functional as F
import torch
class MLP(nn.Module):
def __init__(self, input_dim, output_dim, activation, bias):
super().__init__()
self.input_fc = nn.Linear(input_dim, 200, bias=bias)
self.hidden_fc = nn.Linear(200, 150, bias=bias)
self.output_fc = nn.Linear(150, output_dim, bias=bias)
self.activation = activation
def forward(self, x):
h_1 = self.activation(self.input_fc(x))
h_2 = self.activation(self.hidden_fc(h_1))
y_pred = self.output_fc(h_2)
y_prob = F.softmax(y_pred, dim=1)
return y_prob
def cost(y_prob, target):
torch.manual_seed(10)
cost_array = torch.rand(1,output_dim) # random cost array
idx_max = torch.argmax(y_prob)
# To simplify, the cost is chosen with respect to the node with highest probability
cost = torch.tensor(cost_array[0,idx_max]-target, requires_grad=True)
return cost
torch.manual_seed(10) # just to have same result every run
input_dim, output_dim = 10, 9
activation = F.relu
bias = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Classifier = MLP(input_dim, output_dim, activation, bias).to(device)
learning_rate = 1e-3
optimizer = torch.optim.Adam(Classifier.parameters(), lr=learning_rate)
numEpochs = 10
input_mlp = torch.rand(1,input_dim).to(device)
target = torch.tensor(0).to(device) # target value is set to zero
for epoch in range(numEpochs):
# using weight norm to check if the weights are changing
weight_norm = 0
for param in Classifier.parameters():
weight_norm += torch.norm(param,2).item()
y_prob = Classifier(input_mlp)
loss = cost(y_prob, target)
print('Weight norm: {:.10f}, Loss = {:.4f}'.format(weight_norm, loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()