Advice on ways to identify (and save) the data sample causing a device-side assert error?

Hi. I’m trying to track down a device-side assert error. Only a few samples in my dataset seem to cause the crash, so I think it’d be convenient if there were a way to target them.

So far I’ve tried adding try-except blocks like this:

```python
import pdb

try:
    some_variable = some_embedding_function(x)
except RuntimeError as e:
    print(e)
    pdb.set_trace()  # drop into the debugger to inspect x after the failure
```

This drops me into the PDB shell, but I can’t inspect any of the problematic variables: every time I try to use one of them (say, x), the same error message comes back. I’ve also tried printing or saving the value beforehand, but the device-side error seems to be triggered first and shuts down the program.

Would anybody know if there’s a way to pinpoint the problematic samples? Thanks.

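Device asserts are reported asynchronously, so the backtrace you see typically points at a later operation rather than the one that actually failed. By forcing synchronous kernel launches (setting the environment variable CUDA_LAUNCH_BLOCKING=1), you can pin down the exact function.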
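A minimal sketch of forcing synchronous launches from inside a script. The usual way is simply `CUDA_LAUNCH_BLOCKING=1 python train.py`; the helper below (`enable_blocking_launches` is a hypothetical name) does the same from Python. It assumes, as is commonly advised, that the variable is only honored if it is set before CUDA is initialized, so it refuses to run if PyTorch has already created its CUDA state (checked via `torch.cuda.is_initialized()`).

```python
import os
import sys


def enable_blocking_launches() -> None:
    """Make CUDA kernel launches synchronous so an assert surfaces at the failing op.

    Assumption: CUDA_LAUNCH_BLOCKING must be set before CUDA is initialized;
    setting it before `import torch` is the safest place.
    """
    torch_mod = sys.modules.get("torch")
    if torch_mod is not None and torch_mod.cuda.is_initialized():
        # Too late: the CUDA state already exists and will likely ignore the variable.
        raise RuntimeError("CUDA already initialized; set CUDA_LAUNCH_BLOCKING earlier")
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


enable_blocking_launches()
# Import torch and build the model / DataLoader only after this point.
```

Synchronous launches slow training down noticeably, so this is meant for a debugging run, not for normal use.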
Because a device assert invalidates the GPU context, any data you want to inspect afterwards must be copied to the CPU before you call the failing function. That CPU copy stays accessible after the assert.
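A minimal sketch of the CPU-copy idea, assuming you can pass along some identifier for each sample (the `sample_ids` argument and the helper name `call_with_cpu_snapshot` are hypothetical). It snapshots the input to host memory before the risky call and, if that call raises, pickles the snapshot together with the ids before re-raising. `pickle` is used instead of `torch.save` only to keep the sketch dependency-free; CPU tensors pickle fine either way.

```python
import pickle
from typing import Any, Callable, Sequence


def call_with_cpu_snapshot(fn: Callable[[Any], Any], x: Any,
                           sample_ids: Sequence[int],
                           dump_path: str = "failing_input.pkl") -> Any:
    """Call fn(x); if it raises, save a CPU copy of x plus the sample ids, then re-raise."""
    # Take the CPU copy *before* the risky call. For a torch tensor, .cpu() waits
    # for the device-to-host copy to finish; .clone() keeps in-place ops in fn
    # from altering the snapshot. Objects without .detach are kept as-is.
    x_cpu = x.detach().cpu().clone() if hasattr(x, "detach") else x
    try:
        return fn(x)
    except RuntimeError:
        # The GPU context may now be unusable, but x_cpu lives in host memory.
        with open(dump_path, "wb") as f:
            pickle.dump({"sample_ids": list(sample_ids), "input": x_cpu}, f)
        print(f"call failed; saved sample ids {list(sample_ids)} to {dump_path}")
        raise  # don't keep running on a broken CUDA context
```

In a loop this might look like `out = call_with_cpu_snapshot(some_embedding_function, x, ids.tolist())`, where `ids` is a hypothetical tensor of dataset indices your Dataset would have to return alongside each sample. Run it with CUDA_LAUNCH_BLOCKING=1; otherwise the assert may only be raised by a later operation (possibly the `.cpu()` copy of the next batch), and the wrong sample would be blamed.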

Best regards

Thomas
