Unable to find a valid cuDNN algorithm during autograd

I keep getting the following error when I try to capture the gradients of a non-leaf tensor during training:

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-5a671eebd60a> in <module>
     18 #         group['lr'] = newlr
     19 
<ipython-input-5-80f6a1f0d5ab> in train(model, opt, data_loader, criterion, device)
     19             register_nonleaf_grads(logits)
     20         crossent = criterion(logits, y)
---> 21         crossent.backward(retain_graph=True)

~/miniconda3/envs/tflow/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246 
    247     def register_hook(self, hook):

~/miniconda3/envs/tflow/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148 
    149 

RuntimeError: Unable to find a valid cuDNN algorithm to run convolution

Here’s my approach. Any ideas what I’m doing wrong?

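def register_nonleaf_grads(variable):
    # The hook fires during backward and stores the incoming gradient on the tensor itself
    def hook(grad):
        variable.nonleaf_grads = grad
    variable.register_hook(hook)

def train(model, opt, data, criterion):
    for x, y in data:
        x.requires_grad_()                    # track gradients w.r.t. the input
        logits = model(x)
        register_nonleaf_grads(logits)        # capture d(crossent)/d(logits) via the hook
        crossent = criterion(logits, y)
        crossent.backward(retain_graph=True)  # keep the graph: loss.backward() below traverses it again
        loss = crossent + l2_norm             # l2_norm is defined elsewhere (not shown)
        x.grad.zero_()                        # clear the input grad left by the first backward
        logits.nonleaf_grads.zero_()          # zero the captured logits grad in place
        opt.zero_grad()                       # clear parameter grads before the second backward
        loss.backward()
        opt.step()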

I don’t know which PyTorch, CUDA, and cuDNN versions or which GPU you are using, but since no cuDNN algorithm can be found for your setup, you could disable cuDNN either globally via torch.backends.cudnn.enabled = False or for a specific block of code via with torch.backends.cudnn.flags(enabled=False):.
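Both options can be sketched as follows. This is a minimal, hedged example: the small Conv2d model and random input are stand-ins for the real training setup, and both the forward and the backward pass are wrapped in the context manager to be on the safe side. On a CPU-only machine the cuDNN flag simply has no effect on the computation.

```python
import torch
import torch.nn as nn

# Runs on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Conv2d(3, 8, kernel_size=3).to(device)  # stand-in conv layer, illustration only
x = torch.randn(2, 3, 16, 16, device=device)

original = torch.backends.cudnn.enabled

# Option 1: disable cuDNN only inside this block; the previous setting is restored on exit.
with torch.backends.cudnn.flags(enabled=False):
    enabled_inside = torch.backends.cudnn.enabled
    loss = model(x).sum()  # forward and backward both run with cuDNN disabled
    loss.backward()

enabled_after = torch.backends.cudnn.enabled

# Option 2: disable cuDNN globally for the rest of the process (uncomment to use).
# torch.backends.cudnn.enabled = False
```

The scoped version is usually preferable for debugging, since it limits the (often slower) non-cuDNN convolution path to the code under suspicion.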

PyTorch 1.8.1, CUDA 11.2, on a 1080 GPU.
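The cuDNN version asked about earlier is not in this list. A small sketch that collects all four values from a standard PyTorch install (the helper name environment_report is hypothetical) might look like this. Note that torch.version.cuda reports the CUDA version PyTorch was built with, which can differ from the CUDA toolkit installed on the system.

```python
import torch

def environment_report():
    """Hypothetical helper: collect the versions usually asked for when debugging cuDNN errors."""
    info = {
        "torch": torch.__version__,
        "cuda": torch.version.cuda,               # CUDA version PyTorch was built with; None on CPU-only builds
        "cudnn": torch.backends.cudnn.version(),  # None if cuDNN is not available
        "gpu": None,
    }
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
    return info

for key, value in environment_report().items():
    print(f"{key}: {value}")
```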

What bothers me is that without register_nonleaf_grads everything works fine: there is no complaint about not being able to find a cuDNN algorithm.

Also, if I replace register_nonleaf_grads with logits.retain_grad(), everything works without errors, which leads me to believe that something might be wrong either with register_nonleaf_grads or internally in PyTorch (I could be wrong, though).
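To narrow this down, here is a minimal CPU-only sketch that captures d(loss)/d(logits) in two ways: with retain_grad(), and with a hook that writes into an external dict instead of onto the tensor. The nn.Linear model, random data, and the make_hook helper are hypothetical stand-ins. One visible difference from register_nonleaf_grads is that the closure there holds a reference to the very tensor the hook is registered on, so tensor and hook refer to each other. That could delay freeing the tensor (and, with retain_graph=True, its graph); whether it has anything to do with the cuDNN error is an unconfirmed guess.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(4, 3)                      # stand-in model, illustration only
x = torch.randn(5, 4, requires_grad=True)
y = torch.tensor([0, 1, 2, 0, 1])

# Approach A: retain_grad() keeps .grad populated on the non-leaf tensor.
logits_a = model(x)
logits_a.retain_grad()
F.cross_entropy(logits_a, y).backward()
grad_a = logits_a.grad.clone()

# Approach B: a hook that stores the gradient in an external dict,
# so the closure never references the tensor it is registered on.
captured = {}

def make_hook(name):
    def hook(grad):
        captured[name] = grad.detach().clone()  # returning None leaves the gradient unchanged
    return hook

model.zero_grad()
logits_b = model(x)
handle = logits_b.register_hook(make_hook("logits"))
F.cross_entropy(logits_b, y).backward()
handle.remove()                              # drop the hook once the gradient is captured
grad_b = captured["logits"]
```

Both approaches should yield the same gradient; swapping approach B into the training loop in place of register_nonleaf_grads is one way to test whether the self-referencing closure matters.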