I am just starting to learn PyTorch and am trying to build a custom loss function:
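import torch

class distance_loss(torch.nn.Module):
    def __init__(self, in_features=10, out_features=1):
        super().__init__()  # must run before any nn.Parameter is assigned
        self.in_features = in_features
        self.out_features = out_features
        # nn.Parameter already has requires_grad=True, so no extra flag is needed
        self.center = torch.nn.Parameter(torch.randn(in_features))

    def forward(self, input):
        # L2 distance between the input and the learnable center
        return torch.norm(input - self.center, 2)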
Then I want to use it in some model:
model = SomeModel()
loss = distance_loss()
scores = model(torch.rand(1, 3, 32, 32))
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss.backward()  # this line raises the AttributeError below
But this raises an error: AttributeError: 'distance_loss' object has no attribute 'backward'. It is my understanding that the backward method is automatically implemented in nn.Module. What is going wrong here? Also, do I somehow need to append the learnable parameters of distance_loss() to model.parameters() so that the optimizer knows about them?
.backward() is called on a tensor that results from a computation, once the computation graph has been built; where is your computation? You can see a small use case here:
>>> x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
>>> out = x.pow(2).sum()
>>> out.backward()
>>> x.grad
tensor([[ 2.0000, -2.0000],
        [ 2.0000,  2.0000]])
(This example is taken from https://pytorch.org/docs/stable/tensors.html.)
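The printed gradient follows directly from differentiating the sum of squares element by element:

```latex
\text{out} = \sum_{i,j} x_{ij}^2, \qquad
\frac{\partial\,\text{out}}{\partial x_{ij}} = 2x_{ij}
\quad\Longrightarrow\quad
\nabla_x \text{out} = 2x = \begin{pmatrix} 2 & -2 \\ 2 & 2 \end{pmatrix}
```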
Sorry, I left out one line in the second part of the code; I have edited the post to include it.
I think the confusion comes from having two different things called loss here.
loss in your code is actually the loss module you defined. You need to call it on the model's output to get the loss value for your batch; you can then call .backward() on that value.
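A minimal sketch of that fix, assuming a hypothetical stand-in for SomeModel (a Flatten plus Linear layer that produces 10 scores, so its output broadcasts against the 10-element center). The point it illustrates: loss_fn is the module, loss is the tensor it returns, and only the tensor has .backward().

```python
import torch

class distance_loss(torch.nn.Module):
    # Same loss as in the question: L2 distance to a learnable center.
    def __init__(self, in_features=10, out_features=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.center = torch.nn.Parameter(torch.randn(in_features))

    def forward(self, input):
        return torch.norm(input - self.center, 2)

# Hypothetical stand-in for SomeModel: maps a 1x3x32x32 image to 10 scores.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

loss_fn = distance_loss(10)               # the module (a callable object)
scores = model(torch.rand(1, 3, 32, 32))  # shape (1, 10)
loss = loss_fn(scores)                    # the value: a 0-dim tensor with a grad_fn
loss.backward()                           # backward() lives on the tensor, not the module

print(loss.item())
print(model[1].weight.grad.shape)         # torch.Size([10, 3072])
print(loss_fn.center.grad.shape)          # torch.Size([10])
```

Calling loss_fn(scores) rather than loss_fn.forward(scores) is the idiomatic form; it runs forward() through nn.Module's __call__, which also runs any registered hooks.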
Hi. Thank you for answering!
I changed it to something like this:
model = SomeModel()
x = torch.rand(1,3,32,32)
scores = model(x)
loss_func = distance_loss(10)
loss = loss_func(scores)  # calling the module runs its forward()
loss.backward()
print(x.grad)
but x.grad is now empty (printing it shows None).
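Replace
x = torch.rand(1,3,32,32)
with
x = torch.rand(1,3,32,32, requires_grad=True).
x is a leaf tensor you created yourself, and autograd only fills in .grad for leaf tensors that have requires_grad=True.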
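A small sketch contrasting the two cases, assuming a hypothetical Flatten plus Linear model in place of SomeModel and a simple sum-of-squares stand-in for the loss:

```python
import torch

# Hypothetical small model standing in for SomeModel.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

# Without requires_grad, autograd does not track x, so its .grad stays None.
x_plain = torch.rand(1, 3, 32, 32)
model(x_plain).pow(2).sum().backward()
print(x_plain.grad)                       # None

# With requires_grad=True, x is a tracked leaf, so backward() fills x.grad.
x = torch.rand(1, 3, 32, 32, requires_grad=True)
model(x).pow(2).sum().backward()
print(x.grad.shape)                       # torch.Size([1, 3, 32, 32])
```

In both cases the model's own parameters still receive gradients; requires_grad on the input only matters if you need the gradient with respect to the input itself.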
Thank you!! Is there a way to have the optimizer also perform gradient descent on the learnable parameters of my custom loss function, together with the parameters of the model?
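One common approach (not taken from this thread) is to pass both parameter sets to the optimizer. A minimal sketch, assuming a hypothetical Flatten plus Linear stand-in for SomeModel and a trimmed version of distance_loss:

```python
import torch

class distance_loss(torch.nn.Module):
    def __init__(self, in_features=10):
        super().__init__()
        self.center = torch.nn.Parameter(torch.randn(in_features))

    def forward(self, input):
        return torch.norm(input - self.center, 2)

torch.manual_seed(0)
# Hypothetical stand-in for SomeModel.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
loss_fn = distance_loss(10)

# Hand the optimizer both parameter lists. Per-parameter groups, e.g.
# [{"params": model.parameters()}, {"params": loss_fn.parameters(), "lr": 0.01}],
# would also let the center use its own learning rate.
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(loss_fn.parameters()), lr=0.001
)

center_before = loss_fn.center.detach().clone()
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(torch.rand(8, 3, 32, 32)))
    loss.backward()
    optimizer.step()

print(torch.equal(center_before, loss_fn.center.detach()))  # False: center was updated
```

Alternatively, if the loss module is assigned as an attribute of the model (for example a hypothetical self.criterion = distance_loss(10) inside SomeModel.__init__), nn.Module registers it as a submodule, and model.parameters() then already includes center.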