Non-deterministic NaN error in "SlowConv3DBackward" despite fixed random seeds

Hi,

After struggling with this for a day I'm stuck, so any ideas would be greatly appreciated!
I’m trying to run a 3D convolution on a synthetic MNIST dataset in which each image is stretched along a third dimension with torch.einsum('i,jkl->jikl', torch.ones(28), img), so that each batch has shape (1, 1, 28, 28, 28).

Running the following code in VS Code produces NaN values in the lss.backward() call on a single image in roughly 20% of runs.

import numpy as np  
import torch
from torch import nn, optim
import torchvision
import random

# Pin every random number generator the script touches so that runs are
# repeatable (torch: weight init and DataLoader shuffling; numpy / random: any
# third-party code that draws from them).
_TORCH_SEED = 5
_PY_NUMPY_SEED = 42

torch.cuda.manual_seed(_TORCH_SEED)
torch.manual_seed(_TORCH_SEED)
np.random.seed(_PY_NUMPY_SEED)
random.seed(_PY_NUMPY_SEED)


class AddDimension(torch.nn.Module):
    """Stretch a 2D image into a 3D volume by repeating it along a new depth axis.

    Maps an image tensor of shape (C, H, W) to (C, depth, H, W), where every
    depth slice is an exact copy of the input image.

    Args:
        depth: number of copies along the new axis (default 28, the original
            hard-coded value).
    """

    def __init__(self, depth=28):
        super().__init__()
        self.depth = depth

    def forward(self, img):
        # Same result as torch.einsum('i,jkl->jikl', torch.ones(depth), img),
        # but without building a CPU float32 ones-vector: the output keeps
        # img's dtype and device, and is contiguous for the following Conv3d.
        return img.unsqueeze(1).repeat(1, self.depth, 1, 1)
        


# Per-sample preprocessing: ToTensor turns the PIL digit into a (1, 28, 28)
# float tensor, then AddDimension stretches it into a (1, 28, 28, 28) volume.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    AddDimension()
])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

# Pinned (page-locked) host memory only speeds up host-to-GPU copies. The model
# in this script is never moved off the CPU, so only pin when CUDA exists.
_PIN_MEMORY = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=_PIN_MEMORY
)

# NOTE: num_workers > 0 starts worker processes. Under the "spawn" start method
# (the default on macOS since Python 3.8) each worker re-imports the main
# module, so the module-level script code in this file must be placed under an
# `if __name__ == "__main__":` guard before test_loader is iterated.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=4
)



class ConvAutoEncoder4(torch.nn.Module):
    """Small 3D convolutional autoencoder for single-channel volumes.

    Encoder: a 4x4x4 Conv3d with stride 4 (28 -> 7 per spatial dim for a 28^3
    input) followed by two 1x1x1 Conv3d layers. Decoder: a 4x4x4
    ConvTranspose3d with stride 4 (7 -> 28 per spatial dim). Every layer is
    followed by a ReLU, so the output is non-negative.

    When ``self.print`` is True, ``forward`` additionally prints debug output.
    Debug mode only adds printing; the computation is the same either way.
    """

    def __init__(self):
        super().__init__()
        channels = 4
        self.fcn3 = nn.Conv3d(in_channels=1, out_channels=channels, kernel_size=4, stride=4)
        self.fcn4 = nn.Conv3d(in_channels=channels, out_channels=channels, kernel_size=1)
        self.fcn5 = nn.Conv3d(in_channels=channels, out_channels=channels, kernel_size=1)
        self.fcn6 = nn.ConvTranspose3d(in_channels=channels, out_channels=1, kernel_size=4, stride=4)
        self.print = False  # debug-print flag, toggled via set_print()

    def set_print(self, pr):
        """Enable (True) or disable (False) debug printing in ``forward``."""
        self.print = pr

    def forward(self, tensor):
        """Encode and reconstruct ``tensor``; returns a tensor of the input's shape
        for 28^3 volumes (assumed input layout (N, 1, D, H, W) — TODO confirm)."""
        tensor = torch.relu(self.fcn3(tensor))
        if self.print:
            # Shape of the bottleneck after the strided encoder conv.
            print(tensor.shape)
        tensor = torch.relu(self.fcn4(tensor))
        tensor = torch.relu(self.fcn5(tensor))
        tensor = torch.relu(self.fcn6(tensor))
        if self.print:
            # param.grad holds gradients from the most recent backward() (None
            # before the first one) — not gradients of this forward pass.
            for name, param in self.named_parameters():
                print(name, param.grad)
        return tensor


# Collect the first 100 single-image batches from the shuffled training loader,
# keeping only the image tensors and dropping the labels.
it = iter(train_loader)
imgs = []
while len(imgs) < 100:
    images, _labels = next(it)
    imgs.append(images)

print(len(imgs))

sample = imgs[0]
print(sample.shape)

mod = ConvAutoEncoder4()
crtrn = nn.MSELoss()
ptmzr = optim.Adam(mod.parameters(), lr=1e-3)

# Anomaly mode makes backward() raise at the first function whose gradient
# contains NaN, reporting the forward-pass traceback that created it.
torch.autograd.set_detect_anomaly(True)

# Enable the model's debug printing for this step.
mod.set_print(True)

# A single optimisation step that reconstructs the sample (input == target).
ptmzr.zero_grad()
res = mod(sample)
lss = crtrn(res, sample)
lss.backward()
ptmzr.step()
print(lss.item())

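It fails with the following stack trace (its line numbers refer to a slightly different version of the script than the one posted above):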

  File "/Users/.../python/MNIST3D.py", line 118, in <module>
    res = mod(imgs[0])
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/.../python/MNIST3D.py", line 93, in forward
    tensor = self.fcn5(tensor)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 572, in forward
    return F.conv3d(input, self.weight, self.bias, self.stride,
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "/Users/.../python/MNIST3D.py", line 120, in <module>
    lss.backward()
  File "/usr/local/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'SlowConv3DBackward' returned nan values in its 1th output.

Any clues?