My problem is with this line (in Python):
target_net.load_state_dict(policy_net.state_dict())
Because there isn’t a libtorch equivalent, I used this workaround:
// Copies the parameters and buffers of target_model (source) into model (destination).
void loadstatedict(torch::nn::Module& model, torch::nn::Module& target_model)
{
    torch::NoGradGuard no_grad; // disable autograd while copying; restored on scope exit
    auto params = model.named_parameters(true /*recurse*/);
    auto buffers = model.named_buffers(true /*recurse*/);
    auto new_params = target_model.named_parameters(true /*recurse*/);
    auto new_buffers = target_model.named_buffers(true /*recurse*/);
    for (auto& val : new_params)
    {
        auto* t = params.find(val.key());
        if (t != nullptr)
            t->copy_(val.value());
    }
    // Buffers (e.g. batch-norm running stats) are not in named_parameters(), so copy them separately.
    for (auto& val : new_buffers)
    {
        auto* t = buffers.find(val.key());
        if (t != nullptr)
            t->copy_(val.value());
    }
}
Everything seems to work, but after long runs I usually get these errors from CUDA at runtime:
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\ScatterGatherKernel.cu:145: block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
I have tried:
CUDA_LAUNCH_BLOCKING=1
torch::cuda::synchronize() before the copy
Moving both models to the cpu before copying.
Running only on the CPU (it doesn’t SEEM to happen when running only on the CPU).
It feels like a problem syncing data between the CPU and GPU. Everything runs fine for long runs until I start doing the copy (loadstatedict), and even then it seems to work 99% of the time. If I disable the copy, it never crashes.
Running the code with CUDA_LAUNCH_BLOCKING=1 should show you the line of code causing this issue. Once isolated you could print the indices and check why they contain wrong values.
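As a rough illustration of the "print the indices and check" step, here is a minimal sketch that only uses the standard library. It assumes you have already copied the index tensor to the CPU and flattened it into a `std::vector<std::int64_t>`; the helper name `find_out_of_bounds` and the parameter `index_size` (the size of the dimension being gathered along) are made up for this example and are not part of libtorch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns the positions of all entries in `indices`
// that fall outside [0, index_size). An empty result means every index
// satisfies the condition checked by the failing assertion.
std::vector<std::size_t> find_out_of_bounds(const std::vector<std::int64_t>& indices,
                                            std::int64_t index_size)
{
    assert(index_size > 0); // precondition: the gathered dimension is non-empty
    std::vector<std::size_t> bad;
    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        if (indices[i] < 0 || indices[i] >= index_size)
            bad.push_back(i);
    }
    return bad;
}
```

Calling something like this right before the gather/scatter call and logging the offending positions should make it easier to see where the bad values come from.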
Good to hear you were able to isolate the invalid indexing!
I’m not familiar enough with the code, so I don’t know which value should be used, or whether next_q_values is wrong or just needs clipping.
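If clipping does turn out to be acceptable, it could look roughly like the sketch below. This is only an illustration on a plain `std::vector` of indices; `clamp_indices` and `index_size` are hypothetical names, and whether clamping is correct for your training code is an assumption I can’t verify. Keep in mind that clamping only silences the assert and may hide whatever is producing the bad index in the first place.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: forces every index into [0, index_size - 1].
// This avoids the out-of-bounds assert but does not explain the bad values.
std::vector<std::int64_t> clamp_indices(std::vector<std::int64_t> indices,
                                        std::int64_t index_size)
{
    assert(index_size > 0); // precondition: there is at least one valid index
    for (auto& idx : indices)
        idx = std::min(std::max<std::int64_t>(idx, 0), index_size - 1);
    return indices;
}
```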