Hi, I think I managed to make a reproducible example:
from itertools import count
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
# Length of each random input vector produced by BatchDataset.
INPUT_SIZE = 100
# Length of each random target vector, and the output width of ZeroModel.
TARGET_SIZE = 5
# Samples per DataLoader batch; also used to derive the dataset length.
BATCH_SIZE = 32
class ZeroModel(nn.Module):
    """Minimal module that returns a fresh all-zeros tensor for every batch.

    The output depends only on the input's batch size, dtype and device;
    the input values themselves are never read.

    Args:
        out_size: Number of columns in the tensor returned by ``forward``.
    """

    def __init__(self, out_size: int):
        super().__init__()
        # Registered parameter that ``forward`` never reads.
        # NOTE(review): presumably present only so the optimizer receives a
        # non-empty parameter list -- confirm with the author.
        self._par = nn.Parameter(torch.ones(1, dtype=torch.float), requires_grad=True)
        self._out_size = out_size

    def forward(self, batch: Tensor) -> Tensor:
        """Return zeros of shape ``(batch.size(0), out_size)``.

        The result matches ``batch`` in dtype and device and is created with
        ``requires_grad=True`` so that calling ``backward()`` on a loss
        computed from it is allowed.
        """
        batch_size = batch.size(0)
        zeros = torch.zeros(
            batch_size,
            self._out_size,
            dtype=batch.dtype,
            device=batch.device,
            requires_grad=True,
        )
        return zeros
class BatchDataset(Dataset):
    """Synthetic dataset that generates random (input, target) pairs on demand.

    Nothing is stored: every ``__getitem__`` call draws new samples, so the
    index is ignored and repeated reads of the same index differ.

    Args:
        input_size: Length of each generated input vector.
        target_size: Length of each generated target vector.
    """

    def __init__(self, input_size, target_size):
        self._input_size = input_size
        self.target_size = target_size

    def __len__(self):
        """Return the fixed nominal length, ``BATCH_SIZE * 10000`` samples."""
        return BATCH_SIZE * 10000

    def __getitem__(self, index):
        """Return a pair of NumPy arrays drawn uniformly from [0, 1).

        ``index`` is accepted to satisfy the Dataset protocol but is unused.
        """
        return (
            np.random.uniform(0, 1, size=self._input_size),
            np.random.uniform(0, 1, size=self.target_size),
        )
# --- Reproduction driver -----------------------------------------------------
# Builds the synthetic dataset, a multi-worker DataLoader with pinned memory,
# and a DataParallel-wrapped ZeroModel on CUDA, then trains indefinitely.
dataset = BatchDataset(INPUT_SIZE, TARGET_SIZE)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
model = nn.DataParallel(ZeroModel(TARGET_SIZE).cuda())
optim = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
criterion = nn.MSELoss()

batch_iter = iter(loader)
# ``count()`` never terminates: the loop runs until the process is stopped
# or an error is raised.
for _ in tqdm(count()):
    try:
        batch, targets = next(batch_iter)
    except StopIteration:
        # The loader is exhausted: start a new pass and take its first batch.
        batch_iter = iter(loader)
        batch, targets = next(batch_iter)
    batch, targets = batch.cuda(), targets.cuda()
    output = model(batch)
    loss = criterion(output, targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
With this code the error occurs on my machine around 5k-20k iterations.
I think there is a race condition, since it occurs on a different iteration each time.
Also, it seems that there is some kind of interaction with other modules, since I can’t reproduce it with just torch
installed.
I’m using Python installed from Anaconda.
Here is some more info:
$ uname -a
Linux raven 5.0.0-37-generic #40~18.04.1-Ubuntu SMP Thu Nov 14 12:06:39 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux
$ python --version
Python 3.7.5
$ nvidia-smi --query-gpu=name,driver_version --format=csv
name, driver_version
GeForce RTX 2080 Ti, 435.21
GeForce RTX 2080 Ti, 435.21
$ pip freeze
absl-py==0.8.1
attrs==19.3.0
backcall==0.1.0
cachetools==3.1.1
certifi==2019.11.28
chardet==3.0.4
cycler==0.10.0
decorator==4.4.1
google-auth==1.8.2
google-auth-oauthlib==0.4.1
grpcio==1.25.0
idna==2.8
importlib-metadata==1.3.0
ipython==7.10.2
ipython-genutils==0.2.0
jedi==0.15.1
jsonschema==3.2.0
jupyter-core==4.6.1
kiwisolver==1.1.0
Markdown==3.1.1
matplotlib==3.1.2
more-itertools==8.0.2
mpl-finance==0.10.0
multitasking==0.0.9
nbformat==4.4.0
numpy==1.17.4
oauthlib==3.1.0
pandas==0.25.3
parso==0.5.2
pexpect==4.7.0
pickleshare==0.7.5
plotly==4.4.1
prompt-toolkit==3.0.2
protobuf==3.11.1
ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.7
Pygments==2.5.2
pyparsing==2.4.5
pyrsistent==0.15.6
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.2
requests==2.22.0
requests-oauthlib==1.3.0
retrying==1.3.3
rsa==4.0
six==1.13.0
tensorboard==2.1.0
torch==1.3.1
tqdm==4.40.2
traitlets==4.3.3
urllib3==1.25.7
wcwidth==0.1.7
Werkzeug==0.16.0
yacs==0.1.6
yfinance==0.1.52
zipp==0.6.0