Detaching one of the operands of a tensor multiplication from a computation graph causes "RuntimeError: Function 'MulBackward0' returned nan values in its 1th output."

Hi all,

Can anyone help me avoid the error in the subject line? I wrote a function that is called inside the forward pass at every iteration, roughly like this:

def psudoe_fn(input):
    feat_vec = generate_feat_vec(input)
    density, aux = predict_density(feat_vec)  # ~ (batch, 1)
    distr = predict_distr(feat_vec)
    return density * distr 

The function above works fine on its own. What I'm trying to do now is detach density from the computation graph. I replaced the line density, aux = predict_density(feat_vec) with:

    with torch.no_grad():
        density, aux = predict_density(feat_vec)

and, separately, with:

    with torch.autograd.set_grad_enabled(False):
        density, aux = predict_density(feat_vec)
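For reference, here is a minimal sketch of the usual ways to take one operand out of the graph. The layers density_head and distr_head are hypothetical stand-ins for predict_density and predict_distr. In all three variants density ends up with requires_grad=False, so in density * distr only distr (and the layers behind it) receives a gradient.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for the poster's sub-networks (assumption: plain linear layers).
density_head = torch.nn.Linear(4, 1)
distr_head = torch.nn.Linear(4, 3)

feat_vec = torch.randn(2, 4, requires_grad=True)

# Variant 1: do not record the graph for this computation at all.
with torch.no_grad():
    density_a = density_head(feat_vec)

# Variant 2: same effect as variant 1 (set_grad_enabled(False) behaves like no_grad).
with torch.autograd.set_grad_enabled(False):
    density_c = density_head(feat_vec)

# Variant 3: record normally, then cut the result out of the graph.
density_b = density_head(feat_vec).detach()

distr = distr_head(feat_vec)
out = density_a * distr  # (2, 1) * (2, 3) broadcasts to (2, 3)
out.sum().backward()

# None of the density tensors is tracked by autograd...
print(density_a.requires_grad, density_b.requires_grad, density_c.requires_grad)
# ...so the density head receives no gradient from this product.
print(density_head.weight.grad)
```

Because density is treated as a constant, MulBackward0 only has to produce the gradient for distr, which is grad_out * density.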

However, when I run the code under torch.autograd.detect_anomaly(), I get the MulBackward0 error at return density * distr with both modifications.
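One mechanism that produces exactly this message even when every forward value is finite: for out = density * distr, the 1th output of MulBackward0 is the gradient for distr, which autograd computes as grad_out * density. If the gradient arriving from downstream is inf at a position where density is 0, the product is inf * 0 = nan. The sketch below reproduces this with toy tensors; the torch.sqrt term is an assumed stand-in for some downstream loss term, not the poster's code.

```python
import torch

# Toy values (assumption, not the poster's data): all forward values are finite,
# but density contains a zero. density plays the detached operand.
density = torch.tensor([[0.0], [1.0]])
distr = torch.tensor([[0.0, 0.5], [0.2, 0.3]], requires_grad=True)

def forward():
    out = density * distr
    # sqrt(0) = 0 is finite, but its derivative 1 / (2 * sqrt(0)) is inf.
    return torch.sqrt(out).sum()

loss = forward()
print(torch.isfinite(loss).item())  # the loss itself is finite
loss.backward()
grad_plain = distr.grad.clone()
print(grad_plain)  # row 0: inf * 0 = nan; row 1: finite

# Under anomaly detection the same backward pass raises at MulBackward0.
caught = None
try:
    with torch.autograd.detect_anomaly():
        forward().backward()
except RuntimeError as err:
    caught = err
    print(err)
```

Checking the forward tensors for nan/inf cannot catch this case, because the non-finite values exist only in the backward pass.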

I have checked that neither density nor distr contains nan or inf values, using .isnan().any() and .isinf().any(). How can I resolve the error?
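Since the forward values check out, inspecting the gradients themselves may be more informative. A minimal sketch, assuming hooks are acceptable in the training loop: Tensor.register_hook attaches a function that receives the gradient flowing into a tensor, so a hook on the product shows whether a non-finite gradient is already arriving from downstream. The helper check_grad, the list flagged, and the toy tensors are hypothetical.

```python
import torch

flagged = []  # names of tensors whose incoming gradient was non-finite

def check_grad(name):
    """Return a hook that records `name` if the gradient reaching that tensor
    contains nan or inf. Returning None leaves the gradient unchanged."""
    def hook(grad):
        if not torch.isfinite(grad).all():
            flagged.append(name)
            print(f"non-finite gradient flowing into {name}")
    return hook

# Toy tensors (assumption): finite values, with a zero in density.
density = torch.tensor([[0.0], [1.0]])
distr = torch.tensor([[0.0, 0.5], [0.2, 0.3]], requires_grad=True)

out = density * distr
out.register_hook(check_grad("density * distr"))  # gradient entering MulBackward0
distr.register_hook(check_grad("distr"))           # gradient MulBackward0 sends to distr

loss = torch.sqrt(out).sum()  # hypothetical downstream term with an inf derivative at 0
loss.backward()
print(flagged)
```

Registering the same kind of hook on intermediate results of the downstream terms (for example the kl_div and total-variation inputs) would narrow down which one first produces a non-finite gradient.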

P.S. The output of psudoe_fn is processed further downstream, including kl_div and total-variation terms used for regularization. The network is optimized with Adam together with the NVIDIA Apex library.