Hi, I would like to use model a to predict another model b(LeNet)'s parameters, so I need to implement a loss function myself
def custom_loss_function_new(output_batch, features, targets):
batch_loss = 0.0
features = features.to(DEVICE)
targets = targets.to(DEVICE)
for output in output_batch:
output1 = output.clone()
test_model = LeNet5(NUM_CLASSES, GRAYSCALE).to(DEVICE)
optimizer1 = torch.optim.Adam(test_model.parameters(), lr=LEARNING_RATE1)
unflatten_parameters(output, test_model) # override LeNet's parameters
test_model.train()
logits, probas = test_model(features)
cost1 = F.cross_entropy(logits, targets)
cost1 = cost1.requires_grad_()
print("COST1:", cost1)
optimizer1.zero_grad()
cost1.backward()
for name, param in test_model.named_parameters():
if param.requires_grad:
print(name, param.grad)
optimizer1.step()
loss = F.mse_loss(output1, flatten_parameters(test_model))
# print("loss1:",loss)
batch_loss += loss
average_loss = batch_loss / len(output_batch)
return average_loss
above is my loss function and it works well, how ever I want to define my own backforward function in the autograd, so I define a torch.autograd.Function class
class My_class_loss(torch.autograd.Function):
@staticmethod
def forward(ctx, output_batch, features, targets):
average_loss = custom_loss_function_new(output_batch, features, targets)
return average_loss
@staticmethod
def backward(ctx, grad_output):
grads, = ctx.saved_tensors
return grads, None, None
and use below code to calculate loss
cost = custom_loss_function_new(generated_data, features, targets) #1
cost = My_class_loss.apply(generated_data, features, targets) #2
when I use custom_loss_function_new, it just print the correct param.grad
but when I use My_class_loss, I just got None for grad
COST1: tensor(3.3397, requires_grad=True)
features.0.weight None
features.0.bias None
features.3.weight None
features.3.bias None
classifier.0.weight None
classifier.0.bias None
classifier.2.weight None
classifier.2.bias None
classifier.4.weight None
classifier.4.bias None
I have no idea why I get different results in a same function