Subprocess terminates without any exception when passing a multi-dimensional tensor to the model

I made the model's hidden layer a little bigger (256 cells) and used load_state_dict in the subprocess.
When I pass a multi-dimensional tensor to the model, the process is terminated without any exception.

I wrote a demo to reproduce the error:

```python
import torch.multiprocessing as mp
import torch
import torch.nn as nn


class AC(nn.Module):
    def __init__(self, features_n, actions_n):
        super(AC, self).__init__()
        # If the hidden layer has fewer cells, for example 128, no error occurs.
        self.hidden_layer_cells = 256
        self.l1 = nn.Linear(features_n, self.hidden_layer_cells)
        self.l2 = nn.Linear(self.hidden_layer_cells, self.hidden_layer_cells)
        self.actor_linear = nn.Linear(self.hidden_layer_cells, actions_n)
        self.critic_linear = nn.Linear(self.hidden_layer_cells, 1)

    def forward(self, inputs):
        x = torch.tanh(self.l1(inputs))
        x = torch.tanh(self.l2(x))
        pi = self.actor_linear(x)
        q = self.critic_linear(x)
        return pi, q


class Worker2(mp.Process):
    def __init__(self) -> None:
        super(Worker2, self).__init__()
        self.t = AC(10, 10)
        self.tt = AC(10, 10)
        # If I load the state dict from an existing model, the process is
        # terminated when a multi-dimensional tensor is passed to the model.
        self.tt.load_state_dict(self.t.state_dict())

    def run(self):
        while True:
            s = torch.ones(size=(1, 10))
            a = self.t(s)
            ss = torch.cat((s, s))
            # This line terminates the process.
            aa = self.t(ss)


w = Worker2()
w.start()
w.join()
```

I am not able to reproduce your issue. Do you mind giving a bit more detail about your execution environment? Which version of PyTorch? Which operating system?
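A minimal sketch for collecting the details asked for here, assuming it is run with the same interpreter as the failing job. environment_report is a hypothetical helper, not a PyTorch API; PyTorch itself ships python -m torch.utils.collect_env for a fuller report. The multiprocessing start method is included because it may matter when comparing machines: Linux has traditionally defaulted to fork, while macOS has defaulted to spawn since Python 3.8.

```python
import multiprocessing
import platform

import torch


def environment_report():
    """Hypothetical helper: gather the versions asked about in this thread."""
    return {
        "python": platform.python_version(),
        "os": platform.platform(),
        "torch": str(torch.__version__),
        # The first entry of get_all_start_methods() is the platform default.
        "default_start_method": multiprocessing.get_all_start_methods()[0],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```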

@cbalioglu I'm running it in a Docker container.
Operating system and PyTorch version: (details not preserved)

Indeed, when I run it on my MacBook, no problem occurs.

I tested it on a Fedora machine, and although it does not get terminated, it does get stuck on the aa = self.t(ss) line.
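Since the failure shows up as either a silent termination or a hang, it can help to tell the two apart from the parent. A minimal sketch, assuming the worker's loop is bounded (the demo's while True never exits on its own, so a timeout alone would also flag a healthy run). describe_outcome and run_with_timeout are hypothetical helpers; they rely only on the documented multiprocessing behavior that a child killed by signal N reports exitcode -N, which the parent never sees as a Python exception.

```python
import signal


def describe_outcome(exitcode, timed_out):
    """Translate a child's exit status into a short human-readable label."""
    if timed_out:
        return "hung (still running after the timeout)"
    if exitcode == 0:
        return "ok"
    if exitcode < 0:
        # multiprocessing reports death by signal N as exitcode -N.
        return f"killed by {signal.Signals(-exitcode).name}"
    return f"exited with status {exitcode}"


def run_with_timeout(proc, timeout=30.0):
    """Start a (torch.)multiprocessing Process, wait up to `timeout` seconds, then report."""
    proc.start()
    proc.join(timeout)
    timed_out = proc.is_alive()
    if timed_out:
        proc.terminate()  # SIGTERM on POSIX
        proc.join()
    return describe_outcome(proc.exitcode, timed_out)
```

With a bounded version of Worker2, print(run_with_timeout(Worker2())) would then distinguish the Docker behavior (terminated) from the Fedora behavior (stuck).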

The problem is that you are calling self.tt.load_state_dict() in the parent process and then using self.tt in the forked process. There are known issues with "implicitly" passing tensors to child processes via fork. If you move the logic of __init__() into run(), you will mitigate the issue.
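A minimal sketch of that suggestion, assuming the same AC network as the demo: Worker2.__init__ stores only plain Python data, and both models, including the load_state_dict call, are created inside run(), so the child does not inherit any tensors from the parent. The n_steps argument and last_shapes attribute are hypothetical additions so that the loop terminates and the result can be checked by calling run() directly.

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp


class AC(nn.Module):
    """Same actor-critic network as in the demo, with 256 hidden cells."""

    def __init__(self, features_n, actions_n, hidden_layer_cells=256):
        super().__init__()
        self.l1 = nn.Linear(features_n, hidden_layer_cells)
        self.l2 = nn.Linear(hidden_layer_cells, hidden_layer_cells)
        self.actor_linear = nn.Linear(hidden_layer_cells, actions_n)
        self.critic_linear = nn.Linear(hidden_layer_cells, 1)

    def forward(self, inputs):
        x = torch.tanh(self.l1(inputs))
        x = torch.tanh(self.l2(x))
        return self.actor_linear(x), self.critic_linear(x)


class Worker2(mp.Process):
    def __init__(self, n_steps=3):
        super().__init__()
        # Only plain Python data is set here; no tensors are created in the parent.
        self.n_steps = n_steps
        self.last_shapes = None

    def run(self):
        # Both models are built and synchronized inside the child process.
        t = AC(10, 10)
        tt = AC(10, 10)
        tt.load_state_dict(t.state_dict())
        # Bounded instead of `while True` so the sketch exits.
        for _ in range(self.n_steps):
            s = torch.ones(size=(1, 10))
            t(s)
            ss = torch.cat((s, s))  # shape (2, 10), the input that triggered the issue
            pi, q = t(ss)
            self.last_shapes = (tuple(pi.shape), tuple(q.shape))


if __name__ == "__main__":
    w = Worker2()
    w.start()
    w.join()
    print("exit code:", w.exitcode)
```

Note that last_shapes is set in the child's copy of the object, so the parent only sees it when run() is called directly, as in a quick test.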

Overall, my advice would be to avoid forking in all circumstances. Make sure to spawn your child processes so that each one has its own clean state. As you have already experienced, forking can be very fragile and unpredictable (historically, it long predates multithreading and was never meant to be used with multithreaded processes).
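A sketch of the spawn approach, assuming the same AC network as the demo. It uses mp.get_context("spawn") rather than mp.set_start_method so the choice stays local to this code, and it passes the weights explicitly as an argument instead of relying on fork to copy them; torch.multiprocessing moves tensors sent this way into shared memory. worker and n_steps are illustrative names. Because a spawned child re-imports the main module, the process launch must sit under the if __name__ == "__main__": guard.

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp


class AC(nn.Module):
    """Same actor-critic network as in the demo, with 256 hidden cells."""

    def __init__(self, features_n, actions_n, hidden_layer_cells=256):
        super().__init__()
        self.l1 = nn.Linear(features_n, hidden_layer_cells)
        self.l2 = nn.Linear(hidden_layer_cells, hidden_layer_cells)
        self.actor_linear = nn.Linear(hidden_layer_cells, actions_n)
        self.critic_linear = nn.Linear(hidden_layer_cells, 1)

    def forward(self, inputs):
        x = torch.tanh(self.l1(inputs))
        x = torch.tanh(self.l2(x))
        return self.actor_linear(x), self.critic_linear(x)


def worker(state_dict, n_steps=3):
    """Runs in a freshly spawned interpreter; it only sees what is passed in explicitly."""
    model = AC(10, 10)
    model.load_state_dict(state_dict)
    shapes = None
    for _ in range(n_steps):
        ss = torch.ones(size=(2, 10))  # the multi-dimensional input from the demo
        pi, q = model(ss)
        shapes = (tuple(pi.shape), tuple(q.shape))
    return shapes  # ignored by Process, but handy when calling worker() directly


if __name__ == "__main__":
    # A local "spawn" context avoids changing the global start method.
    ctx = mp.get_context("spawn")
    parent_model = AC(10, 10)
    # The weights are handed over explicitly instead of being inherited through fork.
    p = ctx.Process(target=worker, args=(parent_model.state_dict(), 3))
    p.start()
    p.join()
    print("exit code:", p.exitcode)
```

The same idea works with the original subclass style: derive Worker2 from ctx.Process and keep tensor creation inside run().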
