Implementing a backward function for nn.Module

Maybe this is useful for you?
I implemented a custom backward function for a ternary network (I had also read the BinaryNet paper you cited earlier), and I ran into similar problems while doing it.
Perhaps breaking your function down into two pieces will solve your problems, too. At least in my case, it works fine.
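Here is a minimal sketch of what such a split could look like. It assumes the two pieces are (1) a `torch.autograd.Function` subclass that defines both `forward` and `backward`, and (2) a thin `nn.Module` that holds the parameters and calls `Function.apply`. The names `TernarizeSTE` and `TernaryLinear`, the `threshold` of 0.05, and the weight initialization are hypothetical. The backward uses a BinaryNet-style straight-through estimator that zeroes the gradient where |w| > 1. This is an illustration, not the original ternary-network code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TernarizeSTE(torch.autograd.Function):
    """Piece 1 (hypothetical): quantize weights to {-1, 0, +1} with a custom backward."""

    @staticmethod
    def forward(ctx, weight, threshold):
        # Keep the full-precision weights for the backward pass.
        ctx.save_for_backward(weight)
        # sign(w) where |w| > threshold, otherwise 0.
        mask = (weight.abs() > threshold).to(weight.dtype)
        return torch.sign(weight) * mask

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # Straight-through estimator: pass the gradient through unchanged,
        # but cancel it where |w| > 1 (BinaryNet-style clipping).
        grad_weight = grad_output * (weight.abs() <= 1).to(grad_output.dtype)
        # One return value per forward input; the float threshold gets no gradient.
        return grad_weight, None


class TernaryLinear(nn.Module):
    """Piece 2 (hypothetical): an nn.Module wrapper that owns the real-valued weights."""

    def __init__(self, in_features, out_features, threshold=0.05):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.1, 0.1))
        self.threshold = threshold

    def forward(self, x):
        # Autograd routes gradients through TernarizeSTE.backward automatically.
        w_ternary = TernarizeSTE.apply(self.weight, self.threshold)
        return F.linear(x, w_ternary)


# Usage: the module is used like any other layer, and .backward() just works.
torch.manual_seed(0)
layer = TernaryLinear(4, 3)
x = torch.randn(2, 4)
out = layer(x)
out.sum().backward()
```

The point of the split is that the `nn.Module` never implements `backward` itself. It only stores parameters and calls `apply`, while the custom gradient logic lives in the `autograd.Function`, which is where PyTorch expects it.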