Handling OOM using torch.no_grad() throws import errors

I am using a torchtext dataloader, and during inference I run the following:

for batch in dataloader:
    with torch.no_grad():
        try:
            ...  # inference code
        except RuntimeError:
            torch.cuda.empty_cache()
            continue

When an OOM error is encountered and the loop tries to continue, Python throws an import error.
Without torch.no_grad(), the OOM error is handled and execution continues without any errors.
I am not able to understand this behaviour.
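(For context, here is the control flow in question stripped of the torch-specific calls, with stand-in names, so it can run anywhere:)

```python
def run_inference(batches, infer):
    """Run `infer` over `batches`, skipping any batch that raises
    RuntimeError (e.g. a CUDA out-of-memory error)."""
    results = []
    for batch in batches:
        try:
            results.append(infer(batch))
        except RuntimeError:
            # In the real code, torch.cuda.empty_cache() is called here
            # before moving on to the next batch.
            continue
    return results

def fake_infer(batch):
    # Stand-in for the model call: batch 2 "runs out of memory".
    if batch == 2:
        raise RuntimeError("CUDA out of memory")
    return batch * 10

print(run_inference([1, 2, 3], fake_infer))  # [10, 30]
```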

Hi

Could you give the full stack trace of the error please? Which import fails and where?

My bad: while handling the error, I bound the exception to the name re (except RuntimeError as re:). I'm also using the regex library re, so the name collision resulted in the import error. Please ignore.
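For anyone hitting a similar traceback, the collision can be reproduced without torch. except ... as re: makes re a local name, and Python deletes that name when the except block ends, so a later use of the re module in the same function fails (here it surfaces as an UnboundLocalError; the exact error depends on where the collision happens):

```python
import re

def handle_oom():
    try:
        raise RuntimeError("CUDA out of memory")
    except RuntimeError as re:  # shadows the `re` module locally
        msg = str(re)
    # Python deletes the `except ... as` target when the block ends,
    # so `re` is now an unbound local, not the regex module.
    try:
        return re.search("out of memory", msg) is not None
    except UnboundLocalError:
        return "re is shadowed"

print(handle_oom())  # "re is shadowed"
```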
