Hi,
I am still trying to resolve this issue, without success. To make things easier, I have attached example code at the bottom that demonstrates the problem (please note the dependency on my library). Unfortunately, even with the seed set and cuDNN in deterministic mode, there is no way to reproduce exactly after how many iterations the error appears: sometimes it takes 2,000 updates, sometimes it only breaks after 30,000 or more (as below). Note that in this example I always reuse the same batch. Again, I have confirmed that torch.set_anomaly_enabled(True) makes the issue disappear.
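To be precise about what I mean by that last point: wrapping the forward/backward pass in anomaly detection is the only change that reliably avoids the crash for me. A minimal, self-contained sketch of the equivalent context-manager form (plain autograd, no dependency on my library):

```python
import torch

# Equivalent to torch.set_anomaly_enabled(True), but scoped to a block.
# With this enabled around the backward pass, the crash never occurs for me.
x = torch.randn(4, requires_grad=True)
with torch.autograd.detect_anomaly():
    loss = (x * 2.0).sum()
    loss.backward()

print(x.grad is not None)
```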
gdb gives me the following output and backtrace:
Iteration 49800 : loss = 0.6797990798950195 accuracy = 0.5477447509765625
Iteration 49900 : loss = 0.6751381754875183 accuracy = 0.5528717041015625
terminate called after throwing an instance of 'std::out_of_range'
  what(): vector::_M_range_check: __n (which is 18446744073531805265) >= this->size() (which is 3)
Thread 8 "python" received signal SIGABRT, Aborted.
[Switching to Thread 0x1554f567f700 (LWP 8711)]
__GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
51      ../sysdeps/unix/sysv/linux/raise.c: No such file or directory.
(gdb) bt
#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
#1  0x0000155554f7b801 in __GI_abort () at abort.c:79
#2  0x0000155546032957 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x0000155546038ab6 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#4  0x0000155546038af1 in std::terminate() () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#5  0x00001555472e917e in execute_native_thread_routine () from $VENVPATH/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#6  0x0000155554d236db in start_thread (arg=0x1554f567f700) at pthread_create.c:463
#7  0x000015555505c88f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95
I don't quite know what to make of this backtrace myself, but hopefully you can make some sense of it. Do you have any suggestions on how I could proceed from here?
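One thing I can try on my side, in case it helps narrow things down: arming Python's standard faulthandler module before the training loop, so that Python-level tracebacks for all threads are printed when the process receives the fatal signal (SIGABRT is among the signals it catches). This is pure stdlib, nothing PyTorch-specific:

```python
import faulthandler
import sys

# Dump Python tracebacks for every thread on fatal signals
# (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL).
faulthandler.enable(file=sys.stderr, all_threads=True)

# The same effect can be had without code changes:
#   PYTHONFAULTHANDLER=1 python repro.py
```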
import torch
from irim import InvertibleUnet
from irim import MemoryFreeInvertibleModule
torch.manual_seed(0)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Use CUDA if devices are available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---- Parameters ---
# Working with images, for time series or volumes set to 1 or 3, respectively
conv_nd = 2
# Number of Householder projections for constructing 1x1 convolutions
n_householder = 3
# Number of channels for each layer of the Invertible Unet
n_channels = [3]
# Number of hidden channel in the residual functions of the Invertible Unet
n_hidden = [16]
# Downsampling factors
dilations = [1]
# Number of IRIM steps
n_steps = 1
# Number of image channels
im_channels = 3
# Number of total samples
n_samples = 64
im_size = 32
learning_rate = 1e-3
# Construct Invertible Unet
model = InvertibleUnet(n_channels=n_channels, n_hidden=n_hidden, dilations=dilations,
                       conv_nd=conv_nd, n_householder=n_householder)
# Wrap the model for Invert to Learn
model = MemoryFreeInvertibleModule(model)
# Move model to CUDA device if possible
model.to(device)
# Use data parallel if possible
#if torch.cuda.device_count() > 1:
# model = torch.nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
# Input data drawn from a standard normal
x_in = torch.randn(n_samples, im_channels, *[im_size] * conv_nd, device=device)
# Binary labels for each sample
y_in = torch.empty(n_samples, 1, *[im_size] * conv_nd, device=device).random_(2)
torch.set_anomaly_enabled(False)
for i in range(300000):
    optimizer.zero_grad()
    model.zero_grad()
    # Forward computation
    y_est = model.forward(x_in)
    # We use the first channel for prediction
    y_est = y_est[:, :1]
    loss = torch.nn.functional.binary_cross_entropy_with_logits(y_est, y_in)
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        # y_est holds logits, so a logit >= 0 corresponds to probability >= 0.5
        y_est = (y_est >= 0.).float()
        accuracy = torch.mean((y_est == y_in).float())
        print('Iteration', i, ': loss =', loss.item(), 'accuracy =', accuracy.item())