Hi,
After struggling for a day I'm stuck; any ideas would be greatly appreciated!
I’m trying to do a 3D convolution on a synthetic MNIST dataset where the images are stretched in a third dimension with torch.einsum('i,jkl->jikl', torch.ones(28), img)
such that the dimensions become (1,1,28,28,28).
The following code, run in VSCode, has a roughly 20% chance of producing a NaN
value in the lss.backward()
calculation over a single image.
import numpy as np
import torch
from torch import nn, optim
import torchvision
import random
# Fix the seed of every random number generator used by this script
# (Python's `random`, NumPy, and PyTorch on CPU and CUDA).
random.seed(42)
np.random.seed(42)
torch.manual_seed(5)
torch.cuda.manual_seed(5)
class AddDimension(torch.nn.Module):
    """Transform that inserts a new axis at position 1 by replicating the input.

    A tensor of shape ``(j, k, l)`` becomes ``(j, depth, k, l)``; every slice
    along the new axis is a copy of the input.

    Args:
        depth: Size of the inserted axis. Defaults to 28, the value that was
            previously hard-coded, so ``AddDimension()`` behaves as before.
    """

    def __init__(self, depth: int = 28):
        super().__init__()
        self.depth = depth

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return ``img`` replicated ``depth`` times along a new axis 1.

        Args:
            img: A 3-D tensor of shape ``(j, k, l)``.

        Returns:
            A 4-D tensor of shape ``(j, depth, k, l)`` with the same dtype
            and device as ``img``.
        """
        # new_ones inherits dtype and device from img, so both einsum operands
        # always match (a bare torch.ones is default-dtype on the CPU).
        ones = img.new_ones(self.depth)
        return torch.einsum('i,jkl->jikl', ones, img)
# Per-sample preprocessing: convert the image to a tensor, then apply
# AddDimension to it.
transform = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor(), AddDimension()]
)

# MNIST train / test splits, downloaded into ~/torch_datasets when missing.
train_dataset = torchvision.datasets.MNIST(
    "~/torch_datasets", train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    "~/torch_datasets", train=False, download=True, transform=transform
)

# Both loaders yield one sample per batch; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
class ConvAutoEncoder4(torch.nn.Module):
    """Small 3-D convolutional autoencoder with an optional debug-print mode.

    The input passes through a strided ``Conv3d`` (kernel 4, stride 4), two
    1x1x1 ``Conv3d`` layers and a ``ConvTranspose3d`` (kernel 4, stride 4),
    with a ReLU after every layer, including the last one.
    """

    def __init__(self):
        super().__init__()
        width = 4
        # Layers are registered in this order; it fixes the parameter names
        # and the order in which their random initial weights are drawn.
        self.fcn3 = nn.Conv3d(in_channels=1, out_channels=width, kernel_size=4, stride=4)
        self.fcn4 = nn.Conv3d(in_channels=width, out_channels=width, kernel_size=1)
        self.fcn5 = nn.Conv3d(in_channels=width, out_channels=width, kernel_size=1)
        self.fcn6 = nn.ConvTranspose3d(in_channels=width, out_channels=1, kernel_size=4, stride=4)
        self.print = False

    def set_print(self, pr):
        """Turn the debug output of ``forward`` on (truthy) or off (falsy)."""
        self.print = pr

    def forward(self, tensor):
        """Run the autoencoder on ``tensor`` and return the reconstruction.

        When debug printing is enabled, the shape of the first activation is
        printed, and after the last layer the name and current ``.grad`` of
        every parameter is printed. The computation itself is the same in
        both modes.
        """
        verbose = bool(self.print)

        tensor = torch.relu(self.fcn3(tensor))
        if verbose:
            print(tensor.shape)

        # Remaining layers, each followed by a ReLU.
        for layer in (self.fcn4, self.fcn5, self.fcn6):
            tensor = torch.relu(layer(tensor))

        if verbose:
            for name, param in self.named_parameters():
                print(name, param.grad)
        return tensor
# Pull the first 100 batches from the training loader, keeping only element 0
# of each batch (presumably the image tensor, with the label dropped).
it = iter(train_loader)
imgs = [next(it)[0] for i in range(100)]
print(len(imgs))
print(imgs[0].shape)
# Model, reconstruction loss and optimizer.
mod = ConvAutoEncoder4()
crtrn = nn.MSELoss()
ptmzr = optim.Adam(mod.parameters(), lr=1e-3)
# Anomaly detection makes autograd raise when a backward function returns NaN
# and report the forward call that created the offending op.
torch.autograd.set_detect_anomaly(True)
# Enable the model's debug prints for the forward pass below.
mod.set_print(True)
# One optimisation step on the first batch, using the input as its own target.
ptmzr.zero_grad()
res = mod(imgs[0])
lss = crtrn(res, imgs[0])
lss.backward()
ptmzr.step()
print(lss.item())
With the following stacktrace:
File "/Users/.../python/MNIST3D.py", line 118, in <module>
res = mod(imgs[0])
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/.../python/MNIST3D.py", line 93, in forward
tensor = self.fcn5(tensor)
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 572, in forward
return F.conv3d(input, self.weight, self.bias, self.stride,
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "/Users/.../python/MNIST3D.py", line 120, in <module>
lss.backward()
File "/usr/local/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function 'SlowConv3DBackward' returned nan values in its 1th output.
Any clues?