When I tried to reimplement the MSELoss proposed here (Loss PyTorch Implementation) with a real example, it gives RuntimeError: derivative for argmax is not implemented, although taking the argmax with nn.MSELoss worked.
Code:
def my_MSELoss(predict, true):
    return ((predict - true)**2).mean()

for epoch in range(5):
    losses = 0.0
    for data, label in train_loader:
        network.zero_grad()
        predict = network(data.unsqueeze(1).float())
        print(predict.shape)
        #predict = predict.argmax(axis = 1)
        loss = my_MSELoss(predict.argmax(axis = 1), label.float())
        print(loss)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        losses += loss
    print("For epoch: {}, losses are: {}".format(epoch + 1, losses))
Which PyTorch version are you using? If I remember correctly, this bug was recently fixed by @albanD in this PR, and I get an error using your my_MSELoss function:
RuntimeError: derivative for argmax is not implemented
I’m sorry, but I don’t know how to “Add argmin, argmax and argsort to the list of non-differentiable functions”.
As for the torch version, it’s 1.5, running on Google Colab.
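(For reference, the installed version can be printed like this:)

import torch
print(torch.__version__)  # e.g. '1.5.0' here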
The idea is that none of your approaches should work, as the derivative of argmax is not implemented. That one of them appeared to work seems to be a bug.
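As a minimal sketch (with a random tensor standing in for your network output), backpropagating through argmax fails no matter how the loss itself is written:

import torch

x = torch.randn(4, 10, requires_grad=True)    # stand-in for the network output
pred = x.argmax(dim=1).float()                # integer indices; this cuts the autograd graph
loss = ((pred - torch.zeros(4)) ** 2).mean()
# Raises a RuntimeError on backward; the exact message depends on the
# PyTorch version ("derivative for argmax is not implemented" or
# "element 0 of tensors does not require grad and does not have a grad_fn")
loss.backward()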
When I tried using max(), it threw another error.
Code is:
def my_MSELoss(predict, true):
    return ((predict - true)**2).mean()

for epoch in range(5):
    losses = 0.0
    for data, label in train_loader:
        # .float() converts the data to the weights' dtype
        # you can check the dtype of the weights via:
        # next(model.conv1.parameters()).dtype
        network.zero_grad()
        predict = network(data.unsqueeze(1).float())
        print(predict.shape)
        #predict = predict.argmax(axis = 1)
        _, y_preds = predict.max(axis = 1)
        print(y_preds.shape)
        loss = my_MSELoss(y_preds, label.float())
        print(loss)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        losses += loss
    print("For epoch: {}, losses are: {}".format(epoch + 1, losses))
Output is:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
y_preds is still the argmax. You could get the gradients for the first output (which is the max value). In that case max would be differentiable and the gradients would just flow back to the maximum value, while all other values would get a 0 gradient.
argmax, on the other hand, is not differentiable, as it returns integer values, so you can’t get gradients for it.
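A short sketch of the difference (random tensor, purely illustrative):

import torch

x = torch.randn(2, 5, requires_grad=True)

values, indices = x.max(dim=1)
print(values.grad_fn)   # a MaxBackward node: the max values are differentiable
print(indices.grad_fn)  # None: the indices (the argmax) carry no grad history
print(indices.dtype)    # torch.int64 — integer output

# Backprop through the max values: the gradient is 1 at each row's
# maximum and 0 everywhere else
values.sum().backward()
print(x.grad)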