NaNs in net parameters after backward() during reinforcement learning

I have the following scenario: I’m trying to do reinforcement learning inside an existing C++ codebase, so I decided to use the C++ API. I get a stream of gray-scale images as input and want 36 output values for every image (frame). I currently do a forward pass for every frame, and after an arbitrary number of frames I end the episode by passing in the rewards for each frame and doing some backward passes to train the network. Unfortunately, at some point after the first episode I always get NaNs in my net. This is my code, which I have already heavily simplified, but the problem still occurs:
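
For context, this is roughly how the class is driven from the surrounding code (episodeRunning, grabNextFrame() and computeReward() are placeholders for codebase-specific parts that are not shown here):

ReLeOptimizer optimizer;
std::vector<double> rewards;

while (episodeRunning)
{
    cv::Mat frame = grabNextFrame();                   // 64x64 gray-scale frame
    auto chosenValues = optimizer.processImpl(frame);  // 6 of the 36 output values
    rewards.push_back(computeReward(chosenValues));    // one reward per frame
}

optimizer.endEpisode(rewards);                         // triggers the backward passes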

header file

class ReLeOptimizer
{
public:
    ReLeOptimizer();

    std::vector<double> processImpl(cv::Mat& inMat) override; // overrides a virtual in a base class that is not shown here
    void endEpisode(const std::vector<double>& rewards);

private:
    uint selectFromRange(torch::TensorAccessor<float, 1> accessor, uint begin, uint end);
    uint selectDirection(uint lampId, const at::Tensor& probabilities);
    uint selectDistance(uint lampId, const at::Tensor& probabilities);

private:
    // Define a new Module.
    struct Net : torch::nn::Module
    {
        Net();
        torch::Tensor forward(at::Tensor& x);

        torch::nn::Linear fc1{nullptr};
    } m_net;

    std::vector<torch::Tensor> m_policyHistory;

    torch::optim::Adam m_optimizer;
    cv::RNG m_rng;
    torch::Device m_device{torch::kCPU};
};

cpp file

ReLeOptimizer::ReLeOptimizer()
    : m_net()
    , m_optimizer(m_net.parameters())
{
    m_optimizer.zero_grad();
    m_net.train();
    m_net.to(m_device);
}

std::vector<double> ReLeOptimizer::processImpl(cv::Mat& inMat)
{
    std::vector<double> output;

    // prepare the image tensor
    auto tensor =
        torch::from_blob(inMat.data, {inMat.rows * inMat.cols}, at::kByte).to(at::kFloat);

    // do the forward pass
    auto out = m_net.forward(tensor).set_requires_grad(true);

    std::vector<float> actions(36, 0.f);

    // here I "choose" 6 of the 36 values and mark the index of the chosen action with a 1.f in the action vector
    for (size_t i = 0; i < 6; ++i)
    {
        auto chosenId = chooseInRange(out, i*6, i*6+6);
        actions[chosenId] = 1.f;

        // push the result to output
        output.push_back(out[chosenId]);
    }

    // create the policy tensor and push it into the history
    long actionSize   = actions.size();
    auto actionTensor = torch::from_blob(actions.data(), {actionSize}).set_requires_grad(true);
    auto policyTensor = out.mul(actionTensor);

    m_policyHistory.push_back(policyTensor.sum(-1));

    return output;
}
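
// For completeness, a stand-in for the selection helper declared in the header
// but omitted from this simplified listing. The exact selection strategy should
// not matter for the NaN problem; this sketch just samples an index in
// [begin, end) proportionally to the (post-ReLU, non-negative) output values,
// using the cv::RNG member.
uint ReLeOptimizer::selectFromRange(torch::TensorAccessor<float, 1> accessor, uint begin, uint end)
{
    float total = 0.f;
    for (uint i = begin; i < end; ++i)
        total += accessor[i];

    // if the whole range is zero, fall back to a uniform choice
    if (total <= 0.f)
        return begin + static_cast<uint>(m_rng.uniform(0, static_cast<int>(end - begin)));

    float threshold = m_rng.uniform(0.f, total);
    float running   = 0.f;
    for (uint i = begin; i < end; ++i)
    {
        running += accessor[i];
        if (running >= threshold)
            return i;
    }
    return end - 1;
}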

void ReLeOptimizer::endEpisode(const std::vector<double>& rewards)
{
    double discountFactor = 0.99;
    uint episodeLength    = rewards.size();
    //torch::autograd::AnomalyMode::set_enabled(true);

    std::vector<float> discountedRewards(episodeLength);
    std::partial_sum(rewards.rbegin(), rewards.rend(), discountedRewards.rbegin(),
                     [discountFactor](double sum, double next) -> float { return (discountFactor * sum) + next; });
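
    // sanity check on the indexing: with rewards = {1, 0, 2} and discountFactor = 0.99,
    // the partial_sum over the reversed range gives
    // discountedRewards = {1 + 0.99 * (0 + 0.99 * 2), 0 + 0.99 * 2, 2}
    //                   = {2.9602, 1.98, 2},
    // i.e. every entry is its own reward plus the discounted sum of all later rewards,
    // which is what I intend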

    // calculate reward mean and standard deviation
    float sum  = std::reduce(std::begin(discountedRewards), std::end(discountedRewards), 0.0);
    float mean = sum / static_cast<float>(rewards.size());
    float sd   = std::accumulate(std::begin(discountedRewards), std::end(discountedRewards), 0.0,
                               [mean](float a, float b) { return a + ((b - mean) * (b - mean)); });
    sd /= static_cast<float>(rewards.size());
    sd = std::sqrt(sd);

    std::transform(std::begin(discountedRewards), std::end(discountedRewards), discountedRewards.begin(),
                   [mean, sd](float f) { return (f - mean) / sd; });

    m_net.zero_grad();
    m_optimizer.zero_grad();
    for (uint framesProcessed = 0; framesProcessed < episodeLength; ++framesProcessed)
    {
        auto policy = m_policyHistory[framesProcessed].set_requires_grad(true);

        auto loss = torch::from_blob(&discountedRewards[framesProcessed], 1).mul(-1);

        policy.backward(loss);
        m_optimizer.step();
    }

    //  reset
    m_policyHistory.clear();
}

ReLeOptimizer::Net::Net()
{
    fc1 = register_module("fc1", torch::nn::Linear(64 * 64, 36));
}

at::Tensor ReLeOptimizer::Net::forward(at::Tensor& x)
{
    return torch::relu(fc1->forward(x));
}

I have been trying to pin this problem down for weeks now, but I don’t know what’s going on. If I enable anomaly detection in endEpisode, I hit an exception in the backward call (the stack trace shows torch::autograd::Engine::execute() as the last call). If I don’t enable anomaly detection, I just get a few NaNs in the fc1 parameters right after m_optimizer.step().
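
For reference, this is roughly how I check the parameters after the step (just a debugging loop, not part of the class above):

for (const auto& p : m_net.named_parameters())
{
    if (torch::isnan(p.value()).any().item<bool>())
        std::cout << "NaN in parameter " << p.key() << std::endl;
}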

I’m sorry if it’s too much code; C++ tends to be a bit more long-winded than Python, I guess :D. Also, I roughly followed this tutorial, if that helps: Policy Gradient Reinforcement Learning in PyTorch | by Tim Sullivan | Medium (unfortunately it’s in Python, but it was the closest thing I could find to what I need to do).