Getting NaN in layer output with custom activation function

I implemented a custom activation function that occasionally produces NaNs in the layer output, and I haven't been able to pin down why. I'm currently debugging with a NaN check on the layer outputs, which I hope will let me reproduce the problem more reliably, but I wanted to post the function here in case I'm doing something inherently wrong. The code is below; I've been using it to replace the ReLU layers in VGG11 wholesale.

import torch
import torch.nn as nn


class Renlu(nn.Module):
    """
    PyTorch nn module implementation of the "renlu" activation function,
    where renlu(x, alpha) = 0 if x <= 0 else x^alpha.
    """

    def __init__(self, alpha=0.5):
        super(Renlu, self).__init__()
        self.alpha = float(alpha)

    def forward(self, input_tensor):
        # Rectify first, then raise only the surviving (nonzero) elements to alpha
        output = torch.relu(input_tensor)

        idxs = output.nonzero(as_tuple=True)
        output[idxs] = output[idxs].pow(self.alpha)

        return output
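
The swap itself is just replacing every nn.ReLU module in the network, roughly like this (a simplified sketch using torchvision's vgg11, not my exact code):

from torchvision import models

def swap_relu_for_renlu(module, alpha=0.5):
    # Recursively replace every nn.ReLU submodule with a Renlu instance
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, Renlu(alpha))
        else:
            swap_relu_for_renlu(child, alpha)
    return module

net = swap_relu_for_renlu(models.vgg11())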

Initially I was just returning torch.relu(input_tensor) ** self.alpha, but after adding a call to torch.autograd.set_detect_anomaly(True) I realized that formulation produced NaNs during backprop much more frequently, so I added the nonzero() call to apply the exponent only to the elements that survive rectification.
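
That original formulation does give a NaN gradient at exactly zero, which is easy to reproduce (minimal sketch with alpha = 0.5):

import torch

x = torch.zeros(3, requires_grad=True)
y = (torch.relu(x) ** 0.5).sum()
y.backward()
# pow's backward gives 0.5 * 0**(-0.5) = inf, and ReLU's backward multiplies by 0,
# so x.grad comes out as tensor([nan, nan, nan])
print(x.grad)

Even with the nonzero() workaround, though, I still see the error below on about 10% of the networks I train: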

Warning: Traceback of forward call that caused the error:
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/train_net.py", line 53, in <module>
    main(**vars(args))
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/train_net.py", line 43, in main
    manager.run_training_loop(criterion, optimizer, exp_lr_scheduler,
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/modules/NetManager.py", line 538, in run_training_loop
    (train_acc, train_loss) = self.train_net(criterion, optimizer, scheduler, train_frac)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/modules/NetManager.py", line 489, in train_net
    loss = criterion(outputs, labels)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 915, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/nn/functional.py", line 2021, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/nn/functional.py", line 1317, in log_softmax
    ret = input.log_softmax(dim)
 (print_stack at /opt/conda/conda-bld/pytorch_1579061855666/work/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/train_net.py", line 53, in <module>
    main(**vars(args))
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/train_net.py", line 43, in main
    manager.run_training_loop(criterion, optimizer, exp_lr_scheduler, 
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/modules/NetManager.py", line 538, in run_training_loop
    (train_acc, train_loss) = self.train_net(criterion, optimizer, scheduler, train_frac)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/Source/allen-inst-cell-types/modules/NetManager.py", line 492, in train_net
    loss.backward()
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/allen/programs/braintv/workgroups/nc-ophys/briar.doty/anaconda3/envs/dlct2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'LogSoftmaxBackward' returned nan values in its 0th output.
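
For completeness, the NaN check I mentioned at the top is essentially a forward hook along these lines (simplified sketch, not the exact code I'm running):

def nan_hook(module, inputs, output):
    # Flag the first module whose forward output contains a NaN
    if isinstance(output, torch.Tensor) and torch.isnan(output).any():
        raise RuntimeError(f"NaN in output of {module.__class__.__name__}")

# net is the VGG11 model with the Renlu swap applied
for m in net.modules():
    m.register_forward_hook(nan_hook)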

Any ideas?