Tracking calculations on NLLLoss

I’m trying to make a modified version of NLLLoss, but as I follow the code, I see that it actually delegates to the function at https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L666.

If I follow that path, I see that it executes _functions.thnn.NLLLoss.apply(input, target, weight, size_average, ignore_index), which in turn points to pytorch/pytorch/blob/master/torch/nn/_functions/thnn/auto.py#L36 on the repo.

I can follow the code up to there, and I can step through with a debugger to that point. The problem is that the last line, the one where the loss is actually calculated (I hope), is
getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, target, output, *ctx.additional_args)

There, I get utterly lost. I cannot find the function that is called dynamically, so I don’t know what the actual code is. Even if I try to “step in” with a debugger, I can’t. I see that it calls something called 'ClassNLLCriterion_updateOutput'. I can’t find that string anywhere in the code, so it must be generated dynamically, but from what, and where?

Is there a way to see what that built-in function does? Could someone point me to the code?

Thanks SO much!

Here’s the code it calls:

On CPU: https://github.com/pytorch/pytorch/blob/master/torch/lib/THNN/generic/ClassNLLCriterion.c

On GPU: https://github.com/pytorch/pytorch/blob/master/torch/lib/THCUNN/generic/ClassNLLCriterion.cu
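If you want to prototype a modified loss in Python before reading the C, here is a minimal sketch of what I understand the forward pass to compute for 2D input: the negative log-probability of the target class, optionally weighted, averaged over the summed weights when size_average is set. The name nll_loss_reference is made up for this example, it works on plain lists rather than tensors, and the -100 default for ignore_index is my assumption about the library default, so check it against the C file above.

```python
def nll_loss_reference(log_probs, targets, weight=None,
                       size_average=True, ignore_index=-100):
    """Hypothetical pure-Python reference, not the library's code.

    log_probs: N rows, each a list of C log-probabilities.
    targets:   N zero-based class indices.
    weight:    optional list of C per-class weights.
    """
    total = 0.0
    total_weight = 0.0
    for row, t in zip(log_probs, targets):
        if t == ignore_index:
            continue  # ignored samples contribute nothing
        w = 1.0 if weight is None else weight[t]
        total -= w * row[t]       # accumulate -w * log p(target)
        total_weight += w
    if size_average and total_weight > 0:
        total /= total_weight     # weighted mean over kept samples
    return total
```

Once the numbers from a sketch like this match what NLLLoss gives you on a small input, changing the line inside the loop is an easy way to try out a modified loss.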

Thank You!! That helps a lot!

Is there a way I could have followed the code through to get there myself, or is the meta-programming too deep?

Disclaimer: I’m a Pythonista, but not an expert by any means.

It’s really not worth following it all the way through. You could get there eventually, and by then you could write a couple of blog posts like these:
http://pytorch.org/2017/05/11/Internals.html
http://pytorch.org/2017/06/27/Internals2.html
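To give a feel for why the debugger cannot step any further, here is a toy illustration of name-based dispatch with getattr. It is not PyTorch's actual backend: ToyBackend, register and toy_update_output are hypothetical names, and the real function is compiled C rather than Python, which is why "step in" stops there.

```python
class ToyBackend:
    """Hypothetical stand-in for a backend object; not PyTorch's real one."""

    def register(self, name, fn):
        # Attach the function under a name that is only known at runtime.
        setattr(self, name, fn)


def toy_update_output(state, values, target):
    # Toy forward pass: negative log-probability of the target class.
    return -values[target]


backend = ToyBackend()
# The attribute name is assembled from pieces, so the call site below
# never spells out a normal "def" that an editor could jump to.
backend.register("ClassNLLCriterion" + "_updateOutput", toy_update_output)

fn_name = "ClassNLLCriterion_updateOutput"
result = getattr(backend, fn_name)(None, [-0.5, -1.0], 1)
```

In the toy, getattr resolves the string to an ordinary Python function. In the real library the same lookup lands on a native function, so the only way to read it is to open the C or CUDA source.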

Hi smth, I read the code of _functions.thnn.NLLLoss. What does the line "ignore_index -= TH_INDEX_BASE" mean?
What is the value of TH_INDEX_BASE?

TH_INDEX_BASE is 0 for PyTorch and 1 for Lua-Torch. It’s a compile-time constant that defines whether we are using zero-based or one-based indexing.
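As a small sketch of what that subtraction does, here it is in Python rather than C. The helper to_zero_based is a hypothetical name for illustration, and in the real code the constant is fixed at compile time, not passed as an argument.

```python
TH_INDEX_BASE = 0  # 0 for PyTorch, 1 for Lua-Torch (per the reply above)


def to_zero_based(index, index_base=TH_INDEX_BASE):
    # Hypothetical helper mirroring "ignore_index -= TH_INDEX_BASE":
    # shift a user-facing index into the zero-based form used internally.
    return index - index_base
```

With a base of 0 the subtraction is a no-op, so in PyTorch ignore_index is used as given. With a base of 1, a user-facing index of 5 would become 4 internally.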