Hyperparameter search: how to break out of a DataLoader iterator cleanly

Hi y'all, I've been doing hyperparameter search, and in my case I found that the cost of evaluating various learning statistics and breaking out of epochs early was acceptable. But when my hyperparameter search code did that, I got errors whose tracebacks lead into multiprocessing.py. I presume my function started new worker processes before the old ones had finished, or something of the sort.

I looked at dataloader.py and did not find a method that joins the worker queues to prevent this error. So I just hacked in a time.sleep(arbitrary_n_seconds) and then called trainloader.__iter__()._shutdown_workers().
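A sketch of a more explicit alternative to the sleep-plus-private-method workaround, assuming the goal is only to stop an epoch early and release that epoch's workers before the next DataLoader is built. One thing worth noting, if I read 0.4.0's dataloader.py correctly: each call to `iter(trainloader)` (or `trainloader.__iter__()`) constructs a brand-new iterator with its own worker processes, so calling `_shutdown_workers()` on it shuts down freshly started workers rather than the ones that fed the loop. The sketch below instead keeps a reference to the one iterator it consumes and deletes it when done. `partial_epoch` and `max_batches` are hypothetical names, the dataset is random stand-in data, and I have not verified that this avoids the ConnectionResetError on 0.4.0.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def partial_epoch(loader, max_batches):
    """Consume at most max_batches batches from loader, then release its iterator.

    The iterator is held in a local variable so that the iterator being torn
    down is the same one that produced the batches.
    """
    it = iter(loader)
    n_seen = 0
    try:
        for _ in range(max_batches):
            try:
                inputs, labels = next(it)
            except StopIteration:
                break  # the epoch ended before max_batches was reached
            n_seen += 1
            # ... forward / backward / optimizer step would go here ...
    finally:
        # Dropping the last reference lets CPython run the iterator's __del__
        # (which shuts down its worker processes) here, not at some later point.
        del it
    return n_seen


if __name__ == "__main__":
    data = TensorDataset(torch.randn(1000, 3), torch.randint(0, 10, (1000,)))
    loader = DataLoader(data, batch_size=100, num_workers=2)
    for lr in (0.01, 0.001):  # stand-in for the hyperparameter loop
        print(lr, partial_epoch(loader, max_batches=4))
```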

So, my question: is there such a method built into DataLoader, and if not, shouldn't there be one? It seems like simple process management.

pytorch 0.4.0 from conda
python 3.6.5
ubuntu 16.04
no difference if device is cpu or gpu


Could you show some code demonstrating how you start new processes?

In general it should be fine to simply call break inside the loop iterating through the dataloader.
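A minimal sketch of the plain `break` pattern inside a hyperparameter loop, using random stand-in data instead of CIFAR-10; `train_until` and `stop_after` are hypothetical names. Leaving the `for` loop drops the only reference to the DataLoader's iterator, and that iterator's `__del__` then shuts down its worker processes (the same `__del__` → `_shutdown_workers` path that appears in the traceback in this thread).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_until(loader, stop_after):
    """Iterate over loader but leave the epoch early with a plain break."""
    batches = 0
    for i, (inputs, labels) in enumerate(loader):
        batches += 1
        # ... forward / backward / optimizer step would go here ...
        if i + 1 == stop_after:
            break  # the loop releases its iterator, whose __del__ shuts the workers down
    return batches


if __name__ == "__main__":
    data = TensorDataset(torch.randn(1000, 3), torch.randint(0, 10, (1000,)))
    for lr in (0.01, 0.001, 0.0001):  # stand-in for the hyperparameter loop
        loader = DataLoader(data, batch_size=100, num_workers=2)
        print(lr, train_until(loader, stop_after=4))
```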

Sure. Here is a completely contrived example using CIFAR-10. It doesn't throw an error every time, so you may have to run it a few times before you see it. Put it in a Jupyter notebook and go.
The only thing you may want to change is the path of the data folder; ~/data/ is my go-to location for datasets.

My real training setup is more complex: I am running dozens of tests, loading a different set of hyperparameters each time, and instead of initializing the net anew every time, I load various saved .pth files.
It seems that the more complex the initialization, the more likely the error (but that's just circumstantial evidence, and I don't see why it would make sense).

But as I said, if I let it sleep for a bit just before breaking, the error goes away, so my hunch is that _shutdown_workers could use a join on the worker processes.
Running this on the CPU can also cause these errors.


import os.path as osp
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 120)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
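
def init_net(net, method="xavier", bias=0.0):
    """ initialize net according to standard init
            net: torch.nn module
            method: default: 'xavier', also 'kaiming', 'sparse', 'orthogonal'
            bias: default: 0
    """
    init_ = {"xavier": nn.init.xavier_normal_,
             "kaiming": nn.init.kaiming_normal_,
             # sparse_ needs a sparsity level and only accepts 2-D weights; 0.1 is an assumed value
             "sparse": lambda w: nn.init.sparse_(w, sparsity=0.1),
             "orthogonal": nn.init.orthogonal_}[method]

    for name, module in net.named_modules():
        # only modules that own a weight parameter (the conv and linear layers here)
        if 'weight' in module._parameters.keys():
            if method == "sparse" and module.weight.dim() != 2:
                continue  # conv kernels are 4-D, so skip them for sparse init
            init_(module.weight)
            # fill the bias with a constant if the layer has one
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, bias)
    return net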

CIFARROOT = osp.join(osp.expanduser('~'), 'data/CIFAR')  # change this to your own data folder
assert osp.exists(CIFARROOT), 'path does not exist: ' + CIFARROOT

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = datasets.CIFAR10(root=CIFARROOT, train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root=CIFARROOT, train=False, download=True, transform=transform)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

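_device = 'cuda' if torch.cuda.is_available() else 'cpu'  # the error shows up on CPU too
net = Net().to(_device)

criterion = nn.CrossEntropyLoss()
epochs = 5

tests = (0.01, 0.001, 0.0001, 0.00001)  # learning rates to try
bs = 100
for lr in tests:
    # a fresh DataLoader (and fresh worker processes) for every hyperparameter setting
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=False, num_workers=8)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(_device), labels.to(_device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if i == 3:
                break  # end the epoch early after a few batches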


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f9ef48936d8>>
Traceback (most recent call last):
  File "/home/z/miniconda3/envs/abb/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
  File "/home/z/miniconda3/envs/abb/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/z/miniconda3/envs/abb/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/connection.py", line 493, in Client
    answer_challenge(c, authkey)
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/connection.py", line 737, in answer_challenge
    response = connection.recv_bytes(256)        # reject large message
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/z/miniconda3/envs/abb/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
