Dataloader fails when called with `num_workers > 1`

Hi,

I’m a newcomer to pytorch.

I am trying to replicate a scientific ML model in pytorch which basically has two sub-networks with an inner product in the end. The TL;DR version is two feed-forward networks with ReLU activations in all but the last layer, an inner product and standard MSE loss.

It usually requires large paired input-output samples. I am attaching my code below, with two objectives:

  • Speeding the overall training somehow… A jax version of the same code is quite fast but terrible to read for a newcomer.
  • Getting to know why num_workers>1 case fails.
import torch
from torch import nn
from torch.utils import data
import pickle, random
import numpy as np
from tqdm.notebook import tqdm, trange

# Branch Net

class BranchNet(nn.Module):
    def __init__(self, num_branch_inputs, width, depth):
        super(BranchNet, self).__init__()
        self.num_branch_inputs = num_branch_inputs
        self.width = width
        self.depth = depth
        self.branch = nn.Sequential(
            nn.Linear(num_branch_inputs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width)
        )

    def _weight_init(self):
        for layer in self.branch:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.)
    
    def forward(self, x):
        return self.branch(x)

## Trunk Net

class TrunkNet(nn.Module):
    def __init__(self, num_trunk_inputs, width, depth):
        
        super(TrunkNet, self).__init__()
        self.num_branch_inputs = num_trunk_inputs
        self.width = width
        self.depth = depth
        self.trunk = nn.Sequential(
            nn.Linear(num_trunk_inputs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width)
        )

    def _weight_init(self):
        for layer in self.trunk:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.)

    def forward(self, x):
        return self.trunk(x)

## Total 

class DeepONet:
    def __init__(
        self, branch_width, branch_depth, trunk_width, trunk_depth, num_branch_inputs, num_trunk_inputs
    ):
        self._dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        super(DeepONet, self).__init__()
        self.branch_net = BranchNet(num_branch_inputs, branch_width, branch_depth)
        self.trunk_net = TrunkNet(num_trunk_inputs, trunk_width, trunk_depth)
        self.branch_net._weight_init()
        self.trunk_net._weight_init()
        self.branch_net.to(self._dev)
        self.trunk_net.to(self._dev)
        self.loss_log = []
    
    def get_loss(self, inputs, outputs):
        
        # inputs = inputs.to(self._dev)
        # outputs = outputs.to(self._dev)
        
        u, y = inputs
        pred_branch = self.branch_net(u).to(self._dev)
        pred_trunk = self.trunk_net(y).to(self._dev)
        pred = torch.sum(pred_branch * pred_trunk) # output should be batch_size x 1 ?
        loss = torch.nn.functional.mse_loss(pred.flatten(), outputs.flatten())
        return loss
    
    def train(self, num_epochs, lr, u_train, y_train, s_train):
        dataset = DataGenerator(u_train, y_train, s_train, self._dev)
        dataloader = data.DataLoader(dataset, batch_size=10000, shuffle=False, num_workers=8)   ### THE CODE WORKS if `num_workers=0`
        
        optimizer = torch.optim.Adam([*self.branch_net.parameters(), *self.trunk_net.parameters()], lr=lr)
        
        for epoch in tqdm(range(num_epochs)):
            for i, (inputs, outputs) in tqdm(enumerate(dataloader)):
                loss = self.get_loss(inputs, outputs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                if epoch % 100 == 0:
                    print(f'Epoch {epoch}: {loss.item()}')
                    self.loss_log.append(loss.item())
        return self.loss_log
    
    def save(self, path):
        torch.save(self.branch_net.state_dict(), path+"branch_net.pth")
        torch.save(self.trunk_net.state_dict(), path+"trunk_net.pth")
        with open(path+"loss_log.pkl", "wb") as f:
            pickle.dump(self.loss_log, f)
    
    def load(self, path):
        self.branch_net.load_state_dict(torch.load(path+"branch_net.pth"))
        self.trunk_net.load_state_dict(torch.load(path+"trunk_net.pth"))
        with open(path+"loss_log.pkl", "rb") as f:
            self.loss_log = pickle.load(f)
        
        self.branch_net.eval()
        self.trunk_net.eval()
    
    def predict(self, inputs, outputs):
        u, y = inputs
        pred_branch = self.branch_net(u)
        pred_trunk = self.trunk_net(y)
        pred = torch.sum(pred_branch * pred_trunk) # output should be batch_size x 1 ?
        loss = torch.nn.functional.mse_loss(pred.flatten(), outputs.flatten())
        return loss.item()

## Data Generator

class DataGenerator(data.Dataset):
    def __init__(self, u, y, s, dev):
        'Initialization'
        self.u = u
        self.y = y
        self.s = s
        self._dev = dev
        
        self.N = u.shape[0]

    def __getitem__(self, index):
        'Generate one batch of data'
        inputs = (self.u[index, :].to(self._dev), self.y[index, :].to(self._dev))
        outputs = self.s[index, :].to(self._dev)
        return inputs, outputs
    
    def __len__(self):
        'Denotes the number of batches per epoch'
        return self.N

## Load the data

# =========== CONSIDER FOR THE TIME BEING RANDOM DATA  ===========

u_train = torch.rand(676000, 26)
y_train = torch.rand(676000, 2)
s_train = torch.rand(676000, 1)



# u_train = torch.from_numpy(np.load('u_train.npz')['u_train'])
# y_train = torch.from_numpy(np.load('y_train.npz')['y_train'])
# s_train = torch.from_numpy(np.load('s_train.npz')['s_train'])

deep_o_net = DeepONet(branch_width=10, branch_depth=4, trunk_width=10, trunk_depth=4, num_branch_inputs=u_train.shape[1], num_trunk_inputs=y_train.shape[1])


# ************************ Can I somehow speed up the training ? **********************************
deep_o_net.train(num_epochs=1000, lr=0.0001, u_train=u_train, y_train=y_train, s_train=s_train)  # extremely slow

You want to wrap most of your main script within the if __name__ == '__main__': block, that will prevent the error by making sure those statements don’t run again when each worker process is launched. You can find more details here.

if __name__ == '__main__':

    u_train = torch.rand(676000, 26)
    y_train = torch.rand(676000, 2)
    s_train = torch.rand(676000, 1)

    # u_train = torch.from_numpy(np.load('u_train.npz')['u_train'])
    # y_train = torch.from_numpy(np.load('y_train.npz')['y_train'])
    # s_train = torch.from_numpy(np.load('s_train.npz')['s_train'])

    deep_o_net = DeepONet(branch_width=10, branch_depth=4, trunk_width=10, trunk_depth=4,
                          num_branch_inputs=u_train.shape[1], num_trunk_inputs=y_train.shape[1])

    deep_o_net.train(num_epochs=1000, lr=0.0001, u_train=u_train, y_train=y_train, s_train=s_train)
1 Like

Thanks, I did try wrapping up in a if __name__ == "__main__": block, however that results in the following error:

RuntimeError: cuda runtime error (801) : operation not supported at ..\torch/csrc/generic/StorageSharing.cpp:249
Traceback (most recent call last):
  File "C:\Miniconda3\envs\pytorch\lib\multiprocessing\queues.py", line 245, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "C:\Miniconda3\envs\pytorch\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "C:\Miniconda3\envs\pytorch\lib\site-packages\torch\multiprocessing\reductions.py", line 247, in reduce_tensor
    event_sync_required) = storage._share_cuda_()
RuntimeError: cuda runtime error (801) : operation not supported at ..\torch/csrc/generic/StorageSharing.cpp:249

I am running torch on Windows in a conda env with the following versions installed

# packages in environment at C:\Miniconda3\envs\pytorch:
#
# Name                    Version                   Build  Channel
pytorch                   1.10.1          py3.9_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch  

if that helps.

Thanks again for your help!

Edit: I just happen to have stumbled across this which seems to suggest that pytorch multiprocessing doesn’t work on windows…

For Windows, two alternatives are suggested here.

1 Like

Thanks, Kevin!

Unfortunately setting num_workers to 0 would slow things down even further. I am not sure about the second suggestion; will check it out.