THNN combined loss and grad computation

I’ve been working on porting CTC loss to THNN and got confused by the sparse documentation out there.

The particular implementation I’m basing my work on computes loss and gradients in a single operation (or rather, it doesn’t even calculate an absolute loss). The easiest way to fit this into THNN would be to run the whole alignment twice, once in updateOutput and once in updateGradInput, but that roughly doubles the cost compared to running it once.
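For context on why loss and gradient tend to come out of one pass: in CTC the loss falls out of a forward (alpha) recursion over the target with blanks interleaved, and the gradient is assembled from that alpha table plus a matching backward (beta) table over the same lattice. Recomputing alpha in updateGradInput is exactly the duplicated work described above. The sketch below is a minimal, unoptimized pure-Python version of the forward recursion only. It assumes blank index 0 and per-frame log-probabilities given as nested lists; `ctc_loss` and `logsumexp` are illustrative names, not THNN or PyTorch functions.

```python
import math

NEG_INF = float("-inf")


def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over a few scalars."""
    m = max(xs)
    if m == NEG_INF:  # avoid -inf - -inf = nan
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_loss(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC (forward pass only).

    log_probs: T rows, each a list of C per-class log-probabilities.
    target: label indices; assumption: `blank` (default 0) never appears.
    """
    # Extended sequence with blanks interleaved: _ l1 _ l2 _ ... _
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S, T = len(ext), len(log_probs)

    # alpha[s]: log-prob of all alignments of frames 0..t that end at ext[s]
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]  # stay on the same symbol
            if s >= 1:
                terms.append(alpha[s - 1])  # advance by one symbol
            # Skipping a blank is only allowed between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new

    # Valid alignments end on the last label or on the trailing blank.
    total = alpha[0] if S == 1 else logsumexp(alpha[S - 1], alpha[S - 2])
    return -total
```

With two classes at probability 0.5 each over two frames and target `[1]`, the valid paths are "11", "01" and "10", so the loss is -log(0.75). A fused C implementation would keep the alpha table around (or run beta in the same call) instead of discarding it after the loss is known.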

Some old documentation mentions that updateGradInput has to run after updateOutput, which I take to mean that there is some way to share data between the two functions. All other *Criterion.c implementations just calculate the gradient directly from the input and target arguments, and I am not even pretending to understand the preprocessor hell that is most of THNN.

If there’s someone around who knows THNN well, I’d also like to know how much error checking functions are expected to perform. Some do it rather extensively, others only basic sanity checks.

One way to do this is to run the whole thing twice, as you mentioned, once in updateOutput and once in updateGradInput.

There is a way to share data between the two functions. You can define a Python autograd Function with forward() and backward() methods. forward() can call into the THNN updateOutput and then save the result to a context; backward() can access that context, so you wouldn’t need a THNN updateGradInput at all.
See here for more details about this.
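In PyTorch this corresponds to subclassing `torch.autograd.Function` with static `forward(ctx, ...)` and `backward(ctx, grad_output)` methods, storing tensors with `ctx.save_for_backward` and reading them back via `ctx.saved_tensors`. So the sketch runs without torch, it mimics that shape in plain Python: `Context`, `SquaredErrorLoss` and `fused_loss_and_grad` are hypothetical stand-ins, and a toy squared-error loss takes the place of the CTC alignment. The point is that the expensive fused call happens once, in forward, and backward only rescales the cached gradient.

```python
class Context:
    """Hypothetical stand-in for the `ctx` object autograd passes around."""

    def __init__(self):
        self.saved = ()

    def save_for_backward(self, *values):
        self.saved = values


class SquaredErrorLoss:
    """Hypothetical criterion whose native routine returns loss and grad together."""

    calls = 0  # counts how often the expensive fused routine runs

    @staticmethod
    def fused_loss_and_grad(inputs, targets):
        # Stand-in for a single THNN call that does the costly pass once.
        SquaredErrorLoss.calls += 1
        diffs = [x - y for x, y in zip(inputs, targets)]
        loss = sum(d * d for d in diffs)
        grad = [2 * d for d in diffs]  # d(loss)/d(input)
        return loss, grad

    @staticmethod
    def forward(ctx, inputs, targets):
        loss, grad = SquaredErrorLoss.fused_loss_and_grad(inputs, targets)
        ctx.save_for_backward(grad)  # stash the gradient; no second pass later
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved
        # Chain rule: scale the cached gradient by the upstream gradient.
        # Targets get no gradient, hence None.
        return [grad_output * g for g in grad], None
```

Calling `forward` with inputs `[1.0, 2.0]` and zero targets gives a loss of 5.0, and a following `backward(ctx, 1.0)` returns `[2.0, 4.0]` while the fused routine has run only once. The trade-off is memory: the gradient buffer stays alive between forward and backward.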

The more error checking, the better. Errors in TH that aren’t checked properly can surface as hard-to-parse error messages (or other problems) for the end user.
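As an illustration of the kind of up-front checks a CTC criterion might run before touching any buffers, here is a hypothetical `check_ctc_args` helper in Python. It assumes targets arrive concatenated with per-sequence lengths and that the blank label is 0, a common CTC convention but not something THNN prescribes. The last check relies on the fact that CTC needs a blank frame between repeated labels, so a target of L labels with r adjacent repeats needs at least L + r input frames.

```python
def check_ctc_args(input_lengths, target_lengths, targets, num_classes, blank=0):
    """Hypothetical argument validation for a CTC-style loss.

    Assumptions: `targets` is all target sequences concatenated,
    `target_lengths[i]` is the length of sequence i, blank index is `blank`.
    Raises ValueError with a readable message instead of failing deep inside C.
    """
    if len(input_lengths) != len(target_lengths):
        raise ValueError(
            f"batch size mismatch: {len(input_lengths)} input lengths vs "
            f"{len(target_lengths)} target lengths"
        )
    if sum(target_lengths) != len(targets):
        raise ValueError(
            f"target lengths sum to {sum(target_lengths)} "
            f"but {len(targets)} labels were given"
        )
    for label in targets:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        if label == blank:
            raise ValueError(f"target contains the blank label {blank}")

    offset = 0
    for i, (t_len, l_len) in enumerate(zip(input_lengths, target_lengths)):
        seq = targets[offset:offset + l_len]
        offset += l_len
        # Each adjacent repeat forces an extra blank frame between the labels.
        repeats = sum(1 for a, b in zip(seq, seq[1:]) if a == b)
        if t_len < l_len + repeats:
            raise ValueError(
                f"sequence {i}: {t_len} frames cannot fit {l_len} labels "
                f"with {repeats} repeats"
            )
```

For example, a target `[1, 1]` passes with 3 input frames but is rejected with 2. In a C port the equivalent checks would naturally sit at the top of updateOutput, before any alignment work starts.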

I am aware that shimming the loss into a Python class is one way to solve this; in fact, all the other CTC PyTorch bindings do it this way.

Since wrapping THNN functionality in Python classes inside PyTorch seems to be the legacy way of doing it (judging by all the code in legacy/ that does so), is it realistic to get such a wrapped loss function merged into mainline? The point of this whole exercise is to have CTC loss directly in PyTorch without having to compile extra extensions and so on.